Compare commits

..

6 Commits

Author | SHA1 | Message | Date
Colin P. McCabe | d6ea17073c | test | 2025-09-30 11:58:20 -07:00
BubbleCal | c123bbf391 | Merge branch 'main' of https://github.com/lancedb/lancedb into add-ivfrq | 2025-09-30 16:30:58 +08:00
BubbleCal | fb856005a9 | update docs (Signed-off-by: BubbleCal <bubble-cal@outlook.com>) | 2025-09-29 18:24:58 +08:00
BubbleCal | 5c1c2e2dd6 | fmt (Signed-off-by: BubbleCal <bubble-cal@outlook.com>) | 2025-09-29 17:47:59 +08:00
BubbleCal | 1beef5f6e3 | fix (Signed-off-by: BubbleCal <bubble-cal@outlook.com>) | 2025-09-29 17:08:12 +08:00
BubbleCal | 0913632584 | feat: support IVF_RQ index type (Signed-off-by: BubbleCal <bubble-cal@outlook.com>) | 2025-09-29 16:53:43 +08:00
80 changed files with 1937 additions and 5303 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.22.2"
current_version = "0.22.2-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,45 +0,0 @@
name: Create Failure Issue
description: Creates a GitHub issue if any jobs in the workflow failed
inputs:
job-results:
description: 'JSON string of job results from needs context'
required: true
workflow-name:
description: 'Name of the workflow'
required: true
runs:
using: composite
steps:
- name: Check for failures and create issue
shell: bash
env:
JOB_RESULTS: ${{ inputs.job-results }}
WORKFLOW_NAME: ${{ inputs.workflow-name }}
RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
GH_TOKEN: ${{ github.token }}
run: |
# Check if any job failed
if echo "$JOB_RESULTS" | jq -e 'to_entries | any(.value.result == "failure")' > /dev/null; then
echo "Detected job failures, creating issue..."
# Extract failed job names
FAILED_JOBS=$(echo "$JOB_RESULTS" | jq -r 'to_entries | map(select(.value.result == "failure")) | map(.key) | join(", ")')
# Create issue with workflow name, failed jobs, and run URL
gh issue create \
--title "$WORKFLOW_NAME Failed ($FAILED_JOBS)" \
--body "The workflow **$WORKFLOW_NAME** failed during execution.
**Failed jobs:** $FAILED_JOBS
**Run URL:** $RUN_URL
Please investigate the failed jobs and address any issues." \
--label "ci"
echo "Issue created successfully"
else
echo "No job failures detected, skipping issue creation"
fi
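
As a side note on how this action decides whether to open an issue: the jq expression above scans the `needs` payload that the calling workflows pass in via `toJSON(needs)`. Below is a minimal TypeScript sketch of the same check; the payload shape (one `result` per upstream job) follows the GitHub Actions `needs` context, and the helper name is ours, not part of this changeset.

```ts
// Illustrative sketch of the failure check performed by the jq filter
//   to_entries | any(.value.result == "failure")
// The payload shape mirrors what `toJSON(needs)` produces in the workflows below.
type JobResult = { result: "success" | "failure" | "cancelled" | "skipped" };

function failedJobs(jobResults: Record<string, JobResult>): string[] {
  return Object.entries(jobResults)
    .filter(([, job]) => job.result === "failure")
    .map(([name]) => name); // e.g. ["build"]
}

const needs: Record<string, JobResult> = {
  build: { result: "failure" },
  test: { result: "success" },
};

const failed = failedJobs(needs);
if (failed.length > 0) {
  // The action would create an issue titled "<workflow> Failed (build)".
  console.log(`Detected job failures: ${failed.join(", ")}`);
} else {
  console.log("No job failures detected, skipping issue creation");
}
```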

View File

@@ -38,17 +38,3 @@ jobs:
- name: Publish the package
run: |
cargo publish -p lancedb --all-features --token ${{ steps.auth.outputs.token }}
report-failure:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [build]
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
permissions:
contents: read
issues: write
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/create-failure-issue
with:
job-results: ${{ toJSON(needs) }}
workflow-name: ${{ github.workflow }}

View File

@@ -58,7 +58,7 @@ jobs:
cache: 'npm'
cache-dependency-path: docs/package-lock.json
- name: Install node dependencies
working-directory: nodejs
working-directory: node
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev

View File

@@ -43,6 +43,7 @@ jobs:
- uses: Swatinem/rust-cache@v2
- uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: "1.81.0"
cache-workspaces: "./java/core/lancedb-jni"
# Disable full debug symbol generation to speed up CI build and keep memory down
# "1" means line tables only, which is useful for panic tracebacks.
@@ -111,17 +112,3 @@ jobs:
env:
SONATYPE_USER: ${{ secrets.SONATYPE_USER }}
SONATYPE_TOKEN: ${{ secrets.SONATYPE_TOKEN }}
report-failure:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [linux-arm64, linux-x86, macos-arm64]
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
permissions:
contents: read
issues: write
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/create-failure-issue
with:
job-results: ${{ toJSON(needs) }}
workflow-name: ${{ github.workflow }}

View File

@@ -6,7 +6,6 @@ on:
- main
pull_request:
paths:
- Cargo.toml
- nodejs/**
- .github/workflows/nodejs.yml
- docker-compose.yml

View File

@@ -365,17 +365,3 @@ jobs:
ARGS="$ARGS --tag preview"
fi
npm publish $ARGS
report-failure:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [build-lancedb, test-lancedb, publish]
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
permissions:
contents: read
issues: write
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/create-failure-issue
with:
job-results: ${{ toJSON(needs) }}
workflow-name: ${{ github.workflow }}

View File

@@ -173,17 +173,3 @@ jobs:
generate_release_notes: false
name: Python LanceDB v${{ steps.extract_version.outputs.version }}
body: ${{ steps.python_release_notes.outputs.changelog }}
report-failure:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [linux, mac, windows]
permissions:
contents: read
issues: write
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/create-failure-issue
with:
job-results: ${{ toJSON(needs) }}
workflow-name: ${{ github.workflow }}

View File

@@ -6,7 +6,6 @@ on:
- main
pull_request:
paths:
- Cargo.toml
- python/**
- .github/workflows/python.yml

View File

@@ -125,9 +125,6 @@ jobs:
- name: Run examples
run: cargo run --example simple --locked
- name: Run remote tests
# Running this requires access to secrets, so skip if this is
# a PR from a fork.
if: github.event_name != 'pull_request' || !github.event.pull_request.head.repo.fork
run: make -C ./lancedb remote-tests
macos:

Cargo.lock (generated, 2618 lines changed): file diff suppressed because it is too large.

View File

@@ -15,34 +15,31 @@ categories = ["database-implementations"]
rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"], "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-core = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datagen = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-file = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.38.2", default-features = false, "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-namespace = "0.0.18"
ahash = "0.8"
lance = { "version" = "=0.37.0", default-features = false, "features" = ["dynamodb"], "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.37.0", default-features = false, "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.37.0", "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.37.0", "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.37.0", "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.37.0", "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.37.0", "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.37.0", "tag" = "v0.37.1-beta.1", "git" = "https://github.com/lancedb/lance.git" }
lance-namespace = "0.0.15"
# Note that this one does not include pyarrow
arrow = { version = "56.2", optional = false }
arrow-array = "56.2"
arrow-data = "56.2"
arrow-ipc = "56.2"
arrow-ord = "56.2"
arrow-schema = "56.2"
arrow-cast = "56.2"
arrow = { version = "55.1", optional = false }
arrow-array = "55.1"
arrow-data = "55.1"
arrow-ipc = "55.1"
arrow-ord = "55.1"
arrow-schema = "55.1"
arrow-arith = "55.1"
arrow-cast = "55.1"
async-trait = "0"
datafusion = { version = "50.1", default-features = false }
datafusion-catalog = "50.1"
datafusion-common = { version = "50.1", default-features = false }
datafusion-execution = "50.1"
datafusion-expr = "50.1"
datafusion-physical-plan = "50.1"
datafusion = { version = "49.0", default-features = false }
datafusion-catalog = "49.0"
datafusion-common = { version = "49.0", default-features = false }
datafusion-execution = "49.0"
datafusion-expr = "49.0"
datafusion-physical-plan = "49.0"
env_logger = "0.11"
half = { "version" = "2.6.0", default-features = false, features = [
"num-traits",
@@ -52,26 +49,18 @@ log = "0.4"
moka = { version = "0.12", features = ["future"] }
object_store = "0.12.0"
pin-project = "1.0.7"
rand = "0.9"
snafu = "0.8"
url = "2"
num-traits = "0.2"
rand = "0.9"
regex = "1.10"
lazy_static = "1"
semver = "1.0.25"
crunchy = "0.2.4"
chrono = "0.4"
# Temporary pins to work around downstream issues
# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
chrono = "=0.4.41"
# https://github.com/RustCrypto/formats/issues/1684
base64ct = "=1.6.0"
# Workaround for: https://github.com/Lokathor/bytemuck/issues/306
bytemuck_derive = ">=1.8.1, <1.9.0"
# This is only needed when we reference preview releases of lance
# Force to use the same lance version as the rest of the project to avoid duplicate dependencies
[patch.crates-io]
lance = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }

View File

@@ -16,47 +16,30 @@ check_command_exists() {
}
if [[ ! -e ./lancedb ]]; then
if [[ -v SOPHON_READ_TOKEN ]]; then
INPUT="lancedb-linux-x64"
gh release \
--repo lancedb/lancedb \
download ci-support-binaries \
--pattern "${INPUT}" \
|| die "failed to fetch cli."
check_command_exists openssl
openssl enc -aes-256-cbc \
-d -pbkdf2 \
-pass "env:SOPHON_READ_TOKEN" \
-in "${INPUT}" \
-out ./lancedb-linux-x64.tar.gz \
|| die "openssl failed"
TARGET="${INPUT}.tar.gz"
else
ARCH="x64"
if [[ $OSTYPE == 'darwin'* ]]; then
UNAME=$(uname -m)
if [[ $UNAME == 'arm64' ]]; then
ARCH='arm64'
fi
OSTYPE="macos"
elif [[ $OSTYPE == 'linux'* ]]; then
if [[ $UNAME == 'aarch64' ]]; then
ARCH='arm64'
fi
OSTYPE="linux"
else
die "unknown OSTYPE: $OSTYPE"
ARCH="x64"
if [[ $OSTYPE == 'darwin'* ]]; then
UNAME=$(uname -m)
if [[ $UNAME == 'arm64' ]]; then
ARCH='arm64'
fi
check_command_exists gh
TARGET="lancedb-${OSTYPE}-${ARCH}.tar.gz"
gh release \
--repo lancedb/sophon \
download lancedb-cli-v0.0.3 \
--pattern "${TARGET}" \
|| die "failed to fetch cli."
OSTYPE="macos"
elif [[ $OSTYPE == 'linux'* ]]; then
if [[ $UNAME == 'aarch64' ]]; then
ARCH='arm64'
fi
OSTYPE="linux"
else
die "unknown OSTYPE: $OSTYPE"
fi
check_command_exists gh
TARGET="lancedb-${OSTYPE}-${ARCH}.tar.gz"
gh release \
--repo lancedb/sophon \
download lancedb-cli-v0.0.3 \
--pattern "${TARGET}" \
|| die "failed to fetch cli."
check_command_exists tar
tar xvf "${TARGET}" || die "tar failed."
[[ -e ./lancedb ]] || die "failed to extract lancedb."

View File

@@ -117,7 +117,7 @@ def update_cargo_toml(line_updater):
lance_line = ""
is_parsing_lance_line = False
for line in lines:
if line.startswith("lance") and not line.startswith("lance-namespace"):
if line.startswith("lance"):
# Check if this is a single-line or multi-line entry
# Single-line entries either:
# 1. End with } (complete inline table)

View File

@@ -84,7 +84,6 @@ plugins:
'examples.md': 'https://lancedb.com/docs/tutorials/'
'concepts/vector_search.md': 'https://lancedb.com/docs/search/vector-search/'
'troubleshooting.md': 'https://lancedb.com/docs/troubleshooting/'
'guides/storage.md': 'https://lancedb.com/docs/storage/integrations'
@@ -403,4 +402,4 @@ extra:
- icon: fontawesome/brands/x-twitter
link: https://twitter.com/lancedb
- icon: fontawesome/brands/linkedin
link: https://www.linkedin.com/company/lancedb
link: https://www.linkedin.com/company/lancedb

View File

@@ -1,220 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / PermutationBuilder
# Class: PermutationBuilder
A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
offering methods to configure data splits, shuffling, and filtering before executing
the permutation to create a new table.
## Methods
### execute()
```ts
execute(): Promise<Table>
```
Execute the permutation and create the destination table.
#### Returns
`Promise`&lt;[`Table`](Table.md)&gt;
A Promise that resolves to the new Table instance
#### Example
```ts
const permutationTable = await builder.execute();
console.log(`Created table: ${permutationTable.name}`);
```
***
### filter()
```ts
filter(filter): PermutationBuilder
```
Configure filtering for the permutation.
#### Parameters
* **filter**: `string`
SQL filter expression
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.filter("age > 18 AND status = 'active'");
```
***
### shuffle()
```ts
shuffle(options): PermutationBuilder
```
Configure shuffling for the permutation.
#### Parameters
* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
Configuration for shuffling
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Basic shuffle
builder.shuffle({ seed: 42 });
// Shuffle with clump size
builder.shuffle({ seed: 42, clumpSize: 10 });
```
***
### splitCalculated()
```ts
splitCalculated(calculation): PermutationBuilder
```
Configure calculated splits for the permutation.
#### Parameters
* **calculation**: `string`
SQL expression for calculating splits
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.splitCalculated("user_id % 3");
```
***
### splitHash()
```ts
splitHash(options): PermutationBuilder
```
Configure hash-based splits for the permutation.
#### Parameters
* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
Configuration for hash-based splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.splitHash({
columns: ["user_id"],
splitWeights: [70, 30],
discardWeight: 0
});
```
***
### splitRandom()
```ts
splitRandom(options): PermutationBuilder
```
Configure random splits for the permutation.
#### Parameters
* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
Configuration for random splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Split by ratios
builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
// Split by counts
builder.splitRandom({ counts: [1000, 500], seed: 42 });
// Split with fixed size
builder.splitRandom({ fixed: 100, seed: 42 });
```
***
### splitSequential()
```ts
splitSequential(options): PermutationBuilder
```
Configure sequential splits for the permutation.
#### Parameters
* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
Configuration for sequential splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Split by ratios
builder.splitSequential({ ratios: [0.8, 0.2] });
// Split by counts
builder.splitSequential({ counts: [800, 200] });
// Split with fixed size
builder.splitSequential({ fixed: 1000 });
```

View File

@@ -1,37 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / permutationBuilder
# Function: permutationBuilder()
```ts
function permutationBuilder(table, destTableName): PermutationBuilder
```
Create a permutation builder for the given table.
## Parameters
* **table**: [`Table`](../classes/Table.md)
The source table to create a permutation from
* **destTableName**: `string`
The name for the destination permutation table
## Returns
[`PermutationBuilder`](../classes/PermutationBuilder.md)
A PermutationBuilder instance
## Example
```ts
const builder = permutationBuilder(sourceTable, "training_data")
.splitRandom({ ratios: [0.8, 0.2], seed: 42 })
.shuffle({ seed: 123 });
const trainingTable = await builder.execute();
```

View File

@@ -28,7 +28,6 @@
- [MultiMatchQuery](classes/MultiMatchQuery.md)
- [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md)
- [OAuthHeaderProvider](classes/OAuthHeaderProvider.md)
- [PermutationBuilder](classes/PermutationBuilder.md)
- [PhraseQuery](classes/PhraseQuery.md)
- [Query](classes/Query.md)
- [QueryBase](classes/QueryBase.md)
@@ -77,10 +76,6 @@
- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
- [RemovalStats](interfaces/RemovalStats.md)
- [RetryConfig](interfaces/RetryConfig.md)
- [ShuffleOptions](interfaces/ShuffleOptions.md)
- [SplitHashOptions](interfaces/SplitHashOptions.md)
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
- [TableNamesOptions](interfaces/TableNamesOptions.md)
- [TableStatistics](interfaces/TableStatistics.md)
- [TimeoutConfig](interfaces/TimeoutConfig.md)
@@ -108,4 +103,3 @@
- [connect](functions/connect.md)
- [makeArrowTable](functions/makeArrowTable.md)
- [packBits](functions/packBits.md)
- [permutationBuilder](functions/permutationBuilder.md)

View File

@@ -1,23 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / ShuffleOptions
# Interface: ShuffleOptions
## Properties
### clumpSize?
```ts
optional clumpSize: number;
```
***
### seed?
```ts
optional seed: number;
```

View File

@@ -1,31 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitHashOptions
# Interface: SplitHashOptions
## Properties
### columns
```ts
columns: string[];
```
***
### discardWeight?
```ts
optional discardWeight: number;
```
***
### splitWeights
```ts
splitWeights: number[];
```

View File

@@ -1,39 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitRandomOptions
# Interface: SplitRandomOptions
## Properties
### counts?
```ts
optional counts: number[];
```
***
### fixed?
```ts
optional fixed: number;
```
***
### ratios?
```ts
optional ratios: number[];
```
***
### seed?
```ts
optional seed: number;
```

View File

@@ -1,31 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitSequentialOptions
# Interface: SplitSequentialOptions
## Properties
### counts?
```ts
optional counts: number[];
```
***
### fixed?
```ts
optional fixed: number;
```
***
### ratios?
```ts
optional ratios: number[];
```

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-final.0</version>
<version>0.22.2-beta.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-final.0</version>
<version>0.22.2-beta.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-final.0</version>
<version>0.22.2-beta.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.22.2"
version = "0.22.2-beta.0"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -1,234 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as tmp from "tmp";
import { Table, connect, permutationBuilder } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
describe("PermutationBuilder", () => {
let tmpDir: tmp.DirResult;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const db = await connect(tmpDir.name);
// Create test data
const data = makeArrowTable(
[
{ id: 1, value: 10 },
{ id: 2, value: 20 },
{ id: 3, value: 30 },
{ id: 4, value: 40 },
{ id: 5, value: 50 },
{ id: 6, value: 60 },
{ id: 7, value: 70 },
{ id: 8, value: 80 },
{ id: 9, value: 90 },
{ id: 10, value: 100 },
],
{ vectorColumns: {} },
);
table = await db.createTable("test_table", data);
});
afterEach(() => {
tmpDir.removeCallback();
});
test("should create permutation builder", () => {
const builder = permutationBuilder(table, "permutation_table");
expect(builder).toBeDefined();
});
test("should execute basic permutation", async () => {
const builder = permutationBuilder(table, "permutation_table");
const permutationTable = await builder.execute();
expect(permutationTable).toBeDefined();
expect(permutationTable.name).toBe("permutation_table");
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with random splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
ratios: [1.0],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with percentage splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
ratios: [0.3, 0.7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with count splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
counts: [3, 7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(3);
expect(split1Count).toBe(7);
});
test("should create permutation with hash splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check that splits exist
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with sequential splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitSequential({ ratios: [0.5, 0.5] });
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution - sequential should give exactly 5 and 5
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(5);
expect(split1Count).toBe(5);
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitCalculated("id % 2");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with shuffle", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with shuffle and clump size", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
seed: 42,
clumpSize: 2,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with filter", async () => {
const builder = permutationBuilder(table, "permutation_table").filter(
"value > 50",
);
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100
});
test("should chain multiple operations", async () => {
const builder = permutationBuilder(table, "permutation_table")
.filter("value <= 80")
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
.shuffle({ seed: 123 });
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(8);
});
test("should throw error for invalid split arguments", () => {
const builder = permutationBuilder(table, "permutation_table");
// Test no arguments provided
expect(() => builder.splitRandom({})).toThrow(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
);
// Test multiple arguments provided
expect(() =>
builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }),
).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided");
});
test("should throw error when builder is consumed", async () => {
const builder = permutationBuilder(table, "permutation_table");
// Execute once
await builder.execute();
// Should throw error on second execution
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
});
});

View File

@@ -1,184 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as arrow from "../lancedb/arrow";
import { sanitizeField, sanitizeType } from "../lancedb/sanitize";
describe("sanitize", function () {
describe("sanitizeType function", function () {
it("should handle type objects", function () {
const type = new arrow.Int32();
const result = sanitizeType(type);
expect(result.typeId).toBe(arrow.Type.Int);
expect((result as arrow.Int).bitWidth).toBe(32);
expect((result as arrow.Int).isSigned).toBe(true);
const floatType = {
typeId: 3, // Type.Float = 3
precision: 2,
toString: () => "Float",
isFloat: true,
isFixedWidth: true,
};
const floatResult = sanitizeType(floatType);
expect(floatResult).toBeInstanceOf(arrow.DataType);
expect(floatResult.typeId).toBe(arrow.Type.Float);
const floatResult2 = sanitizeType({ ...floatType, typeId: () => 3 });
expect(floatResult2).toBeInstanceOf(arrow.DataType);
expect(floatResult2.typeId).toBe(arrow.Type.Float);
});
const allTypeNameTestCases = [
["null", new arrow.Null()],
["binary", new arrow.Binary()],
["utf8", new arrow.Utf8()],
["bool", new arrow.Bool()],
["int8", new arrow.Int8()],
["int16", new arrow.Int16()],
["int32", new arrow.Int32()],
["int64", new arrow.Int64()],
["uint8", new arrow.Uint8()],
["uint16", new arrow.Uint16()],
["uint32", new arrow.Uint32()],
["uint64", new arrow.Uint64()],
["float16", new arrow.Float16()],
["float32", new arrow.Float32()],
["float64", new arrow.Float64()],
["datemillisecond", new arrow.DateMillisecond()],
["dateday", new arrow.DateDay()],
["timenanosecond", new arrow.TimeNanosecond()],
["timemicrosecond", new arrow.TimeMicrosecond()],
["timemillisecond", new arrow.TimeMillisecond()],
["timesecond", new arrow.TimeSecond()],
["intervaldaytime", new arrow.IntervalDayTime()],
["intervalyearmonth", new arrow.IntervalYearMonth()],
["durationnanosecond", new arrow.DurationNanosecond()],
["durationmicrosecond", new arrow.DurationMicrosecond()],
["durationmillisecond", new arrow.DurationMillisecond()],
["durationsecond", new arrow.DurationSecond()],
] as const;
it.each(allTypeNameTestCases)(
'should map type name "%s" to %s',
function (name, expected) {
const result = sanitizeType(name);
expect(result).toBeInstanceOf(expected.constructor);
},
);
const caseVariationTestCases = [
["NULL", new arrow.Null()],
["Utf8", new arrow.Utf8()],
["FLOAT32", new arrow.Float32()],
["DaTedAy", new arrow.DateDay()],
] as const;
it.each(caseVariationTestCases)(
'should be case insensitive for type name "%s" mapped to %s',
function (name, expected) {
const result = sanitizeType(name);
expect(result).toBeInstanceOf(expected.constructor);
},
);
it("should throw error for unrecognized type name", function () {
expect(() => sanitizeType("invalid_type")).toThrow(
"Unrecognized type name in schema: invalid_type",
);
});
});
describe("sanitizeField function", function () {
it("should handle field with string type name", function () {
const field = sanitizeField({
name: "string_field",
type: "utf8",
nullable: true,
metadata: new Map([["key", "value"]]),
});
expect(field).toBeInstanceOf(arrow.Field);
expect(field.name).toBe("string_field");
expect(field.type).toBeInstanceOf(arrow.Utf8);
expect(field.nullable).toBe(true);
expect(field.metadata?.get("key")).toBe("value");
});
it("should handle field with type object", function () {
const floatType = {
typeId: 3, // Float
precision: 32,
};
const field = sanitizeField({
name: "float_field",
type: floatType,
nullable: false,
});
expect(field).toBeInstanceOf(arrow.Field);
expect(field.name).toBe("float_field");
expect(field.type).toBeInstanceOf(arrow.DataType);
expect(field.type.typeId).toBe(arrow.Type.Float);
expect((field.type as arrow.Float64).precision).toBe(32);
expect(field.nullable).toBe(false);
});
it("should handle field with direct Type instance", function () {
const field = sanitizeField({
name: "bool_field",
type: new arrow.Bool(),
nullable: true,
});
expect(field).toBeInstanceOf(arrow.Field);
expect(field.name).toBe("bool_field");
expect(field.type).toBeInstanceOf(arrow.Bool);
expect(field.nullable).toBe(true);
});
it("should throw error for invalid field object", function () {
expect(() =>
sanitizeField({
type: "int32",
nullable: true,
}),
).toThrow(
"The field passed in is missing a `type`/`name`/`nullable` property",
);
// Invalid type
expect(() =>
sanitizeField({
name: "invalid",
type: { invalid: true },
nullable: true,
}),
).toThrow("Expected a Type to have a typeId property");
// Invalid nullable
expect(() =>
sanitizeField({
name: "invalid_nullable",
type: "int32",
nullable: "not a boolean",
}),
).toThrow("The field passed in had a non-boolean `nullable` property");
});
it("should report error for invalid type name", function () {
expect(() =>
sanitizeField({
name: "invalid_field",
type: "invalid_type",
nullable: true,
}),
).toThrow(
"Unable to sanitize type for field: invalid_field due to error: Error: Unrecognized type name in schema: invalid_type",
);
});
});
});

View File

@@ -10,13 +10,7 @@ import * as arrow16 from "apache-arrow-16";
import * as arrow17 from "apache-arrow-17";
import * as arrow18 from "apache-arrow-18";
import {
Connection,
MatchQuery,
PhraseQuery,
Table,
connect,
} from "../lancedb";
import { MatchQuery, PhraseQuery, Table, connect } from "../lancedb";
import {
Table as ArrowTable,
Field,
@@ -27,8 +21,6 @@ import {
Int64,
List,
Schema,
SchemaLike,
Type,
Uint8,
Utf8,
makeArrowTable,
@@ -219,7 +211,8 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
},
);
it("should be able to omit nullable fields", async () => {
// TODO: https://github.com/lancedb/lancedb/issues/1832
it.skip("should be able to omit nullable fields", async () => {
const db = await connect(tmpDir.name);
const schema = new arrow.Schema([
new arrow.Field(
@@ -243,36 +236,23 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
await table.add([data3]);
let res = await table.query().limit(10).toArray();
const resVector = res.map((r) =>
r.vector ? Array.from(r.vector) : null,
);
const resVector = res.map((r) => r.get("vector").toArray());
expect(resVector).toEqual([null, data2.vector, data3.vector]);
const resItem = res.map((r) => r.item);
const resItem = res.map((r) => r.get("item").toArray());
expect(resItem).toEqual(["foo", null, "bar"]);
const resPrice = res.map((r) => r.price);
const resPrice = res.map((r) => r.get("price").toArray());
expect(resPrice).toEqual([10.0, 2.0, 3.0]);
const data4 = { item: "foo" };
// We can't omit a column if it's not nullable
await expect(table.add([data4])).rejects.toThrow(
"Append with different schema",
);
await expect(table.add([data4])).rejects.toThrow("Invalid user input");
// But we can alter columns to make them nullable
await table.alterColumns([{ path: "price", nullable: true }]);
await table.add([data4]);
res = (await table.query().limit(10).toArray()).map((r) => ({
...r.toJSON(),
vector: r.vector ? Array.from(r.vector) : null,
}));
// Rust fills missing nullable fields with null
expect(res).toEqual([
{ ...data1, vector: null },
{ ...data2, item: null },
data3,
{ ...data4, price: null, vector: null },
]);
res = (await table.query().limit(10).toArray()).map((r) => r.toJSON());
expect(res).toEqual([data1, data2, data3, data4]);
});
it("should be able to insert nullable data for non-nullable fields", async () => {
@@ -350,43 +330,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
const table = await db.createTable("my_table", data);
expect(await table.countRows()).toEqual(2);
});
it("should allow undefined and omitted nullable vector fields", async () => {
// Test for the bug: can't pass undefined or omit vector column
const db = await connect("memory://");
const schema = new arrow.Schema([
new arrow.Field("id", new arrow.Int32(), true),
new arrow.Field(
"vector",
new arrow.FixedSizeList(
32,
new arrow.Field("item", new arrow.Float32(), true),
),
true, // nullable = true
),
]);
const table = await db.createEmptyTable("test_table", schema);
// Should not throw error for undefined value
await table.add([{ id: 0, vector: undefined }]);
// Should not throw error for omitted field
await table.add([{ id: 1 }]);
// Should still work for null
await table.add([{ id: 2, vector: null }]);
// Should still work for actual vector
const testVector = new Array(32).fill(0.5);
await table.add([{ id: 3, vector: testVector }]);
expect(await table.countRows()).toEqual(4);
const res = await table.query().limit(10).toArray();
const resVector = res.map((r) =>
r.vector ? Array.from(r.vector) : null,
);
expect(resVector).toEqual([null, null, null, testVector]);
});
},
);
@@ -1520,9 +1463,7 @@ describe("when optimizing a dataset", () => {
it("delete unverified", async () => {
const version = await table.version();
const versionFile = `${tmpDir.name}/${table.name}.lance/_versions/${
version - 1
}.manifest`;
const versionFile = `${tmpDir.name}/${table.name}.lance/_versions/${version - 1}.manifest`;
fs.rmSync(versionFile);
let stats = await table.optimize({ deleteUnverified: false });
@@ -2036,52 +1977,3 @@ describe("column name options", () => {
expect(results2.length).toBe(10);
});
});
describe("when creating an empty table", () => {
let con: Connection;
beforeEach(async () => {
const tmpDir = tmp.dirSync({ unsafeCleanup: true });
con = await connect(tmpDir.name);
});
afterEach(() => {
con.close();
});
it("can create an empty table from an arrow Schema", async () => {
const schema = new Schema([
new Field("id", new Int64()),
new Field("vector", new Float64()),
]);
const table = await con.createEmptyTable("test", schema);
const actualSchema = await table.schema();
expect(actualSchema.fields[0].type.typeId).toBe(Type.Int);
expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64);
expect(actualSchema.fields[1].type.typeId).toBe(Type.Float);
expect((actualSchema.fields[1].type as Float64).precision).toBe(2);
});
it("can create an empty table from schema that specifies field types by name", async () => {
const schemaLike = {
fields: [
{
name: "id",
type: "int64",
nullable: true,
},
{
name: "vector",
type: "float64",
nullable: true,
},
],
metadata: new Map(),
names: ["id", "vector"],
} satisfies SchemaLike;
const table = await con.createEmptyTable("test", schemaLike);
const actualSchema = await table.schema();
expect(actualSchema.fields[0].type.typeId).toBe(Type.Int);
expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64);
expect(actualSchema.fields[1].type.typeId).toBe(Type.Float);
expect((actualSchema.fields[1].type as Float64).precision).toBe(2);
});
});

View File

@@ -73,7 +73,7 @@ export type FieldLike =
| {
type: string;
name: string;
nullable: boolean;
nullable?: boolean;
metadata?: Map<string, string>;
};
@@ -1285,36 +1285,19 @@ function validateSchemaEmbeddings(
if (isFixedSizeList(field.type)) {
field = sanitizeField(field);
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
// Check if there's an embedding function registered for this field
let hasEmbeddingFunction = false;
// Check schema metadata for embedding functions
if (schema.metadata.has("embedding_functions")) {
const embeddings = JSON.parse(
schema.metadata.get("embedding_functions")!,
);
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f`
if (embeddings.find((f: any) => f["vectorColumn"] === field.name)) {
hasEmbeddingFunction = true;
}
}
// Check passed embedding function parameter
if (embeddings && embeddings.vectorColumn === field.name) {
hasEmbeddingFunction = true;
}
// If the field is nullable AND there's no embedding function, allow undefined/omitted values
if (field.nullable && !hasEmbeddingFunction) {
fields.push(field);
} else {
// Either not nullable OR has embedding function - require explicit values
if (hasEmbeddingFunction) {
// Don't add to missingEmbeddingFields since this is expected to be filled by embedding function
fields.push(field);
} else {
if (
// biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f`
embeddings.find((f: any) => f["vectorColumn"] === field.name) ===
undefined
) {
missingEmbeddingFields.push(field);
}
} else {
missingEmbeddingFields.push(field);
}
} else {
fields.push(field);
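
One small clarification on the first hunk in this file: making `nullable` optional on `FieldLike` means a schema-like field spec can now omit the flag entirely. The sketch below shows what the relaxed shape accepts; the local alias narrows the real union (which lives in the package's arrow module) to the object variant shown in the hunk, and whether a default is applied downstream is not visible in this diff.

```ts
// Sketch only: the object variant of FieldLike as shown in the hunk above,
// with `nullable` now optional. The real type has additional variants not
// repeated here.
type FieldLikeObject = {
  type: string;
  name: string;
  nullable?: boolean;
  metadata?: Map<string, string>;
};

const priceField: FieldLikeObject = { name: "price", type: "float64", nullable: true };
const idField: FieldLikeObject = { name: "id", type: "int64" }; // `nullable` omitted, now valid
```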

View File

@@ -43,10 +43,6 @@ export {
DeleteResult,
DropColumnsResult,
UpdateResult,
SplitRandomOptions,
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
} from "./native.js";
export {
@@ -115,7 +111,6 @@ export {
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";
export { permutationBuilder, PermutationBuilder } from "./permutation";
export * as rerankers from "./rerankers";
export {
SchemaLike,

View File

@@ -1,188 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import {
PermutationBuilder as NativePermutationBuilder,
Table as NativeTable,
ShuffleOptions,
SplitHashOptions,
SplitRandomOptions,
SplitSequentialOptions,
permutationBuilder as nativePermutationBuilder,
} from "./native.js";
import { LocalTable, Table } from "./table";
/**
* A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
*
* This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
* offering methods to configure data splits, shuffling, and filtering before executing
* the permutation to create a new table.
*/
export class PermutationBuilder {
private inner: NativePermutationBuilder;
/**
* @hidden
*/
constructor(inner: NativePermutationBuilder) {
this.inner = inner;
}
/**
* Configure random splits for the permutation.
*
* @param options - Configuration for random splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
*
* // Split by counts
* builder.splitRandom({ counts: [1000, 500], seed: 42 });
*
* // Split with fixed size
* builder.splitRandom({ fixed: 100, seed: 42 });
* ```
*/
splitRandom(options: SplitRandomOptions): PermutationBuilder {
const newInner = this.inner.splitRandom(options);
return new PermutationBuilder(newInner);
}
/**
* Configure hash-based splits for the permutation.
*
* @param options - Configuration for hash-based splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitHash({
* columns: ["user_id"],
* splitWeights: [70, 30],
* discardWeight: 0
* });
* ```
*/
splitHash(options: SplitHashOptions): PermutationBuilder {
const newInner = this.inner.splitHash(options);
return new PermutationBuilder(newInner);
}
/**
* Configure sequential splits for the permutation.
*
* @param options - Configuration for sequential splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitSequential({ ratios: [0.8, 0.2] });
*
* // Split by counts
* builder.splitSequential({ counts: [800, 200] });
*
* // Split with fixed size
* builder.splitSequential({ fixed: 1000 });
* ```
*/
splitSequential(options: SplitSequentialOptions): PermutationBuilder {
const newInner = this.inner.splitSequential(options);
return new PermutationBuilder(newInner);
}
/**
* Configure calculated splits for the permutation.
*
* @param calculation - SQL expression for calculating splits
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitCalculated("user_id % 3");
* ```
*/
splitCalculated(calculation: string): PermutationBuilder {
const newInner = this.inner.splitCalculated(calculation);
return new PermutationBuilder(newInner);
}
/**
* Configure shuffling for the permutation.
*
* @param options - Configuration for shuffling
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Basic shuffle
* builder.shuffle({ seed: 42 });
*
* // Shuffle with clump size
* builder.shuffle({ seed: 42, clumpSize: 10 });
* ```
*/
shuffle(options: ShuffleOptions): PermutationBuilder {
const newInner = this.inner.shuffle(options);
return new PermutationBuilder(newInner);
}
/**
* Configure filtering for the permutation.
*
* @param filter - SQL filter expression
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.filter("age > 18 AND status = 'active'");
* ```
*/
filter(filter: string): PermutationBuilder {
const newInner = this.inner.filter(filter);
return new PermutationBuilder(newInner);
}
/**
* Execute the permutation and create the destination table.
*
* @returns A Promise that resolves to the new Table instance
* @example
* ```ts
* const permutationTable = await builder.execute();
* console.log(`Created table: ${permutationTable.name}`);
* ```
*/
async execute(): Promise<Table> {
const nativeTable: NativeTable = await this.inner.execute();
return new LocalTable(nativeTable);
}
}
/**
* Create a permutation builder for the given table.
*
* @param table - The source table to create a permutation from
* @param destTableName - The name for the destination permutation table
* @returns A PermutationBuilder instance
* @example
* ```ts
* const builder = permutationBuilder(sourceTable, "training_data")
* .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
* .shuffle({ seed: 123 });
*
* const trainingTable = await builder.execute();
* ```
*/
export function permutationBuilder(
table: Table,
destTableName: string,
): PermutationBuilder {
// Extract the inner native table from the TypeScript wrapper
const localTable = table as LocalTable;
// Access inner through type assertion since it's private
const nativeBuilder = nativePermutationBuilder(
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
(localTable as any).inner,
destTableName,
);
return new PermutationBuilder(nativeBuilder);
}

View File

@@ -326,9 +326,6 @@ export function sanitizeDictionary(typeLike: object) {
// biome-ignore lint/suspicious/noExplicitAny: skip
export function sanitizeType(typeLike: unknown): DataType<any> {
if (typeof typeLike === "string") {
return dataTypeFromName(typeLike);
}
if (typeof typeLike !== "object" || typeLike === null) {
throw Error("Expected a Type but object was null/undefined");
}
@@ -450,7 +447,7 @@ export function sanitizeType(typeLike: unknown): DataType<any> {
case Type.DurationSecond:
return new DurationSecond();
default:
throw new Error("Unrecognized type id in schema: " + typeId);
throw new Error("Unrecoginized type id in schema: " + typeId);
}
}
@@ -470,15 +467,7 @@ export function sanitizeField(fieldLike: unknown): Field {
"The field passed in is missing a `type`/`name`/`nullable` property",
);
}
let type: DataType;
try {
type = sanitizeType(fieldLike.type);
} catch (error: unknown) {
throw Error(
`Unable to sanitize type for field: ${fieldLike.name} due to error: ${error}`,
{ cause: error },
);
}
const type = sanitizeType(fieldLike.type);
const name = fieldLike.name;
if (!(typeof name === "string")) {
throw Error("The field passed in had a non-string `name` property");
@@ -592,46 +581,3 @@ function sanitizeData(
},
);
}
const constructorsByTypeName = {
null: () => new Null(),
binary: () => new Binary(),
utf8: () => new Utf8(),
bool: () => new Bool(),
int8: () => new Int8(),
int16: () => new Int16(),
int32: () => new Int32(),
int64: () => new Int64(),
uint8: () => new Uint8(),
uint16: () => new Uint16(),
uint32: () => new Uint32(),
uint64: () => new Uint64(),
float16: () => new Float16(),
float32: () => new Float32(),
float64: () => new Float64(),
datemillisecond: () => new DateMillisecond(),
dateday: () => new DateDay(),
timenanosecond: () => new TimeNanosecond(),
timemicrosecond: () => new TimeMicrosecond(),
timemillisecond: () => new TimeMillisecond(),
timesecond: () => new TimeSecond(),
intervaldaytime: () => new IntervalDayTime(),
intervalyearmonth: () => new IntervalYearMonth(),
durationnanosecond: () => new DurationNanosecond(),
durationmicrosecond: () => new DurationMicrosecond(),
durationmillisecond: () => new DurationMillisecond(),
durationsecond: () => new DurationSecond(),
} as const;
type MappableTypeName = keyof typeof constructorsByTypeName;
export function dataTypeFromName(typeName: string): DataType {
const normalizedTypeName = typeName.toLowerCase() as MappableTypeName;
const _constructor = constructorsByTypeName[normalizedTypeName];
if (!_constructor) {
throw new Error("Unrecognized type name in schema: " + typeName);
}
return _constructor();
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.22.2",
"version": "0.22.2-beta.0",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.22.2",
"version": "0.22.2-beta.0",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -12,7 +12,6 @@ mod header;
mod index;
mod iterator;
pub mod merge;
pub mod permutation;
mod query;
pub mod remote;
mod rerankers;

View File

@@ -1,222 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;
#[napi(object)]
pub struct SplitRandomOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
}
#[napi(object)]
pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
}
#[napi(object)]
pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
}
#[napi(object)]
pub struct ShuffleOptions {
pub seed: Option<i64>,
pub clump_size: Option<i64>,
}
pub struct PermutationBuilderState {
pub builder: Option<LancePermutationBuilder>,
pub dest_table_name: String,
}
#[napi]
pub struct PermutationBuilder {
state: Arc<Mutex<PermutationBuilderState>>,
}
impl PermutationBuilder {
pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
Self {
state: Arc::new(Mutex::new(PermutationBuilderState {
builder: Some(builder),
dest_table_name,
})),
}
}
}
impl PermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> napi::Result<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[napi]
impl PermutationBuilder {
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
/// Configure hash-based splits
#[napi]
pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<Self> {
let split_weights = options
.split_weights
.into_iter()
.map(|w| w as u64)
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
})
}
/// Configure sequential splits
#[napi]
pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
})
}
/// Configure shuffling
#[napi]
pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<Self> {
let seed = options.seed.map(|s| s as u64);
let clump_size = options.clump_size.map(|c| c as u64);
self.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
/// Configure filtering
#[napi]
pub fn filter(&self, filter: String) -> napi::Result<Self> {
self.modify(|builder| builder.with_filter(filter))
}
/// Execute the permutation builder and create the table
#[napi]
pub async fn execute(&self) -> napi::Result<Table> {
let (builder, dest_table_name) = {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
(builder, dest_table_name)
};
let table = builder.build(&dest_table_name).await.default_error()?;
Ok(Table::new(table))
}
}
/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(
table: &crate::table::Table,
dest_table_name: String,
) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
let inner_table = table.inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PermutationBuilder::new(inner_builder, dest_table_name))
}

View File

@@ -26,7 +26,7 @@ pub struct Table {
}
impl Table {
pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.25.3-beta.0"
current_version = "0.25.2-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -24,19 +24,6 @@ commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# Update Cargo.lock after version bump
pre_commit_hooks = [
"""
cd python && cargo update -p lancedb-python
if git diff --quiet ../Cargo.lock; then
echo "Cargo.lock unchanged"
else
git add ../Cargo.lock
echo "Updated and staged Cargo.lock"
fi
""",
]
[tool.bumpversion.parts.pre_l]
values = ["beta", "final"]
optional_value = "final"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.25.3-beta.0"
version = "0.25.2-beta.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
@@ -14,12 +14,12 @@ name = "_lancedb"
crate-type = ["cdylib"]
[dependencies]
arrow = { version = "56.2", features = ["pyarrow"] }
arrow = { version = "55.1", features = ["pyarrow"] }
async-trait = "0.1"
lancedb = { path = "../rust/lancedb", default-features = false }
env_logger.workspace = true
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.25", features = [
pyo3 = { version = "0.24", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.24", features = [
"attributes",
"tokio-runtime",
] }
@@ -28,7 +28,7 @@ futures.workspace = true
tokio = { version = "1.40", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.25", features = [
pyo3-build-config = { version = "0.24", features = [
"extension-module",
"abi3-py39",
] }

View File

@@ -5,12 +5,12 @@ dynamic = ["version"]
dependencies = [
"deprecation",
"numpy",
"overrides>=0.7; python_version<'3.12'",
"overrides>=0.7",
"packaging",
"pyarrow>=16",
"pydantic>=1.10",
"tqdm>=4.27.0",
"lance-namespace>=0.0.16"
"lance-namespace==0.0.6"
]
description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]

View File

@@ -133,7 +133,6 @@ class Tags:
async def update(self, tag: str, version: int): ...
class IndexConfig:
name: str
index_type: str
columns: List[str]
@@ -296,34 +295,3 @@ class AlterColumnsResult:
class DropColumnsResult:
version: int
class AsyncPermutationBuilder:
def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
def split_random(
self,
*,
ratios: Optional[List[float]] = None,
counts: Optional[List[int]] = None,
fixed: Optional[int] = None,
seed: Optional[int] = None,
) -> "AsyncPermutationBuilder": ...
def split_hash(
self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
) -> "AsyncPermutationBuilder": ...
def split_sequential(
self,
*,
ratios: Optional[List[float]] = None,
counts: Optional[List[int]] = None,
fixed: Optional[int] = None,
) -> "AsyncPermutationBuilder": ...
def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
def shuffle(
self, seed: Optional[int], clump_size: Optional[int]
) -> "AsyncPermutationBuilder": ...
def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
async def execute(self) -> Table: ...
def async_permutation_builder(
table: Table, dest_table_name: str
) -> AsyncPermutationBuilder: ...
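A minimal usage sketch of the async builder stubbed above (illustrative only, not part of this diff). Here builder is assumed to be the result of async_permutation_builder(table, "perm"); the calls simply mirror the exactly-one-of validation implemented in the Rust bindings further down this diff.

# Any one of these sizing arguments is accepted:
builder.split_random(ratios=[0.8, 0.2], seed=42)
builder.split_random(counts=[800, 200])
builder.split_random(fixed=4)  # four equal splits

# Passing more than one sizing argument (or none at all) raises
# ValueError: "Exactly one of 'ratios', 'counts', or 'fixed' must be provided"

permutation = await builder.execute()  # returns a Table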

View File

@@ -5,20 +5,11 @@
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
from pathlib import Path
import sys
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
if sys.version_info >= (3, 12):
from typing import override
class EnforceOverrides:
pass
else:
from overrides import EnforceOverrides, override # type: ignore
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from overrides import EnforceOverrides, override # type: ignore
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
from lancedb.background_loop import LOOP
@@ -41,6 +32,7 @@ import deprecation
if TYPE_CHECKING:
import pyarrow as pa
from .pydantic import LanceModel
from datetime import timedelta
from ._lancedb import Connection as LanceDbConnection
from .common import DATA, URI
@@ -452,12 +444,7 @@ class LanceDBConnection(DBConnection):
read_consistency_interval: Optional[timedelta] = None,
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
_inner: Optional[LanceDbConnection] = None,
):
if _inner is not None:
self._conn = _inner
return
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
@@ -466,6 +453,11 @@ class LanceDBConnection(DBConnection):
uri = Path(uri)
uri = uri.expanduser().absolute()
Path(uri).mkdir(parents=True, exist_ok=True)
self._uri = str(uri)
self._entered = False
self.read_consistency_interval = read_consistency_interval
self.storage_options = storage_options
self.session = session
if read_consistency_interval is not None:
read_consistency_interval_secs = read_consistency_interval.total_seconds()
@@ -484,32 +476,10 @@ class LanceDBConnection(DBConnection):
session,
)
# TODO: It would be nice if we didn't store self.storage_options but it is
# currently used by the LanceTable.to_lance method. This doesn't _really_
# work because some paths like LanceDBConnection.from_inner will lose the
# storage_options. Also, this class really shouldn't be holding any state
# beyond _conn.
self.storage_options = storage_options
self._conn = AsyncConnection(LOOP.run(do_connect()))
@property
def read_consistency_interval(self) -> Optional[timedelta]:
return LOOP.run(self._conn.get_read_consistency_interval())
@property
def session(self) -> Optional[Session]:
return self._conn.session
@property
def uri(self) -> str:
return self._conn.uri
@classmethod
def from_inner(cls, inner: LanceDbConnection):
return cls(None, _inner=inner)
def __repr__(self) -> str:
val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
val = f"{self.__class__.__name__}(uri={self._uri!r}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
@@ -519,10 +489,6 @@ class LanceDBConnection(DBConnection):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)
@property
def _inner(self) -> LanceDbConnection:
return self._conn._inner
@override
def list_namespaces(
self,
@@ -882,13 +848,6 @@ class AsyncConnection(object):
def uri(self) -> str:
return self._inner.uri
async def get_read_consistency_interval(self) -> Optional[timedelta]:
interval_secs = await self._inner.get_read_consistency_interval()
if interval_secs is not None:
return timedelta(seconds=interval_secs)
else:
return None
async def list_namespaces(
self,
namespace: List[str] = [],

View File

@@ -12,18 +12,13 @@ from __future__ import annotations
from typing import Dict, Iterable, List, Optional, Union
import os
import sys
if sys.version_info >= (3, 12):
from typing import override
else:
from overrides import override
from lancedb.db import DBConnection
from lancedb.table import LanceTable, Table
from lancedb.util import validate_table_name
from lancedb.common import validate_schema
from lancedb.table import sanitize_create_table
from overrides import override
from lance_namespace import LanceNamespace, connect as namespace_connect
from lance_namespace_urllib3_client.models import (

View File

@@ -1,72 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from ._lancedb import async_permutation_builder
from .table import LanceTable
from .background_loop import LOOP
from typing import Optional
class PermutationBuilder:
def __init__(self, table: LanceTable, dest_table_name: str):
self._async = async_permutation_builder(table, dest_table_name)
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
self._async.select(projections)
return self
def split_random(
self,
*,
ratios: Optional[list[float]] = None,
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
seed: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
return self
def split_hash(
self,
columns: list[str],
split_weights: list[int],
*,
discard_weight: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
return self
def split_sequential(
self,
*,
ratios: Optional[list[float]] = None,
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
return self
def split_calculated(self, calculation: str) -> "PermutationBuilder":
self._async.split_calculated(calculation)
return self
def shuffle(
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
) -> "PermutationBuilder":
self._async.shuffle(seed=seed, clump_size=clump_size)
return self
def filter(self, filter: str) -> "PermutationBuilder":
self._async.filter(filter)
return self
def execute(self) -> LanceTable:
async def do_execute():
inner_tbl = await self._async.execute()
return LanceTable.from_inner(inner_tbl)
return LOOP.run(do_execute())
def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
return PermutationBuilder(table, dest_table_name)
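For orientation, a rough end-to-end sketch of the wrapper above (illustrative only: db is assumed to be an open LanceDB connection and the table names are made up; the chained calls mirror those exercised by the permutation tests later in this diff):

import pyarrow as pa
from lancedb.permutation import permutation_builder

tbl = db.create_table("events", pa.table({"id": range(100), "value": range(100)}))
perm = (
    permutation_builder(tbl, "events_permutation")
    .filter("value >= 10")
    .split_random(ratios=[0.8, 0.2], seed=42)
    .shuffle(seed=7)
    .execute()
)
# The permutation table records a row_id and split_id for every surviving row.
print(perm.count_rows())  # 90: rows with value >= 10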

View File

@@ -5,20 +5,15 @@
from datetime import timedelta
import logging
from concurrent.futures import ThreadPoolExecutor
import sys
from typing import Any, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse
import warnings
if sys.version_info >= (3, 12):
from typing import override
else:
from overrides import override
# Remove this import to fix circular dependency
# from lancedb import connect_async
from lancedb.remote import ClientConfig
import pyarrow as pa
from overrides import override
from ..common import DATA
from ..db import DBConnection, LOOP

View File

@@ -114,7 +114,7 @@ class RemoteTable(Table):
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*,
replace: bool = False,
wait_timeout: Optional[timedelta] = None,
wait_timeout: timedelta = None,
name: Optional[str] = None,
):
"""Creates a scalar index
@@ -153,7 +153,7 @@ class RemoteTable(Table):
column: str,
*,
replace: bool = False,
wait_timeout: Optional[timedelta] = None,
wait_timeout: timedelta = None,
with_position: bool = False,
# tokenizer configs:
base_tokenizer: str = "simple",

View File

@@ -74,7 +74,6 @@ from .index import lang_mapping
if TYPE_CHECKING:
from .db import LanceDBConnection
from ._lancedb import (
Table as LanceDBTable,
OptimizeStats,
@@ -89,6 +88,7 @@ if TYPE_CHECKING:
MergeResult,
UpdateResult,
)
from .db import LanceDBConnection
from .index import IndexConfig
import pandas
import PIL
@@ -1707,38 +1707,22 @@ class LanceTable(Table):
namespace: List[str] = [],
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
_async: AsyncTable = None,
):
self._conn = connection
self._namespace = namespace
if _async is not None:
self._table = _async
else:
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
)
@property
def name(self) -> str:
return self._table.name
@classmethod
def from_inner(cls, tbl: LanceDBTable):
from .db import LanceDBConnection
async_tbl = AsyncTable(tbl)
conn = LanceDBConnection.from_inner(tbl.database())
return cls(
conn,
async_tbl.name,
_async=async_tbl,
)
@classmethod
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
tbl = cls(db, name, namespace=namespace, **kwargs)
@@ -2772,10 +2756,6 @@ class LanceTable(Table):
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
@property
def _inner(self) -> LanceDBTable:
return self._table._inner
@deprecation.deprecated(
deprecated_in="0.21.0",
current_version=__version__,

View File

@@ -1,496 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import pytest
from lancedb.permutation import permutation_builder
def test_split_random_ratios(mem_db):
"""Test random splitting with ratios."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(ratios=[0.3, 0.7])
.execute()
)
# Check that the table was created and has data
assert permutation_tbl.count_rows() == 100
# Check that split_id column exists and has correct values
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1}
# Check approximate split sizes (allowing for rounding)
split_0_count = split_ids.count(0)
split_1_count = split_ids.count(1)
assert 25 <= split_0_count <= 35 # ~30% ± tolerance
assert 65 <= split_1_count <= 75 # ~70% ± tolerance
def test_split_random_counts(mem_db):
"""Test random splitting with absolute counts."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(counts=[20, 30])
.execute()
)
# Check that we have exactly the requested counts
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert split_ids.count(0) == 20
assert split_ids.count(1) == 30
def test_split_random_fixed(mem_db):
"""Test random splitting with fixed number of splits."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
)
# Check that we have 4 splits with 25 rows each
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1, 2, 3}
for split_id in range(4):
assert split_ids.count(split_id) == 25
def test_split_random_with_seed(mem_db):
"""Test that seeded random splits are reproducible."""
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
# Create two identical permutations with same seed
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
# Results should be identical
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
assert data1["row_id"] == data2["row_id"]
assert data1["split_id"] == data2["split_id"]
def test_split_hash(mem_db):
"""Test hash-based splitting."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100], # Repeating pattern
"value": range(100),
}
),
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
# Should have all 100 rows (no discard)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1}
# Verify that each split has roughly 50 rows (allowing for hash variance)
split_0_count = split_ids.count(0)
split_1_count = split_ids.count(1)
assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution
assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution
# Hash splits should be deterministic - same category should go to same split
# Let's verify by creating another permutation and checking consistency
perm2 = (
permutation_builder(tbl, "test_permutation2")
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
data2 = perm2.search(None).to_arrow().to_pydict()
assert data["split_id"] == data2["split_id"] # Should be identical
def test_split_hash_with_discard(mem_db):
"""Test hash-based splitting with discard weight."""
tbl = mem_db.create_table(
"test_table",
pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
.execute()
)
# Should have fewer than 100 rows due to discard
row_count = permutation_tbl.count_rows()
assert row_count < 100
assert row_count > 0 # But not empty
def test_split_sequential(mem_db):
"""Test sequential splitting."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_sequential(counts=[30, 40])
.execute()
)
assert permutation_tbl.count_rows() == 70
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
split_ids = data["split_id"]
# Sequential should maintain order
assert row_ids == sorted(row_ids)
# First 30 should be split 0, next 40 should be split 1
assert split_ids[:30] == [0] * 30
assert split_ids[30:] == [1] * 40
def test_split_calculated(mem_db):
"""Test calculated splitting."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_calculated("id % 3") # Split based on id modulo 3
.execute()
)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
split_ids = data["split_id"]
# Verify the calculation: each row's split_id should equal row_id % 3
for i, (row_id, split_id) in enumerate(zip(row_ids, split_ids)):
assert split_id == row_id % 3
def test_split_error_cases(mem_db):
"""Test error handling for invalid split parameters."""
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
# Test split_random with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error1").split_random().execute()
# Test split_random with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error2").split_random(
ratios=[0.5, 0.5], counts=[5, 5]
).execute()
# Test split_sequential with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error3").split_sequential().execute()
# Test split_sequential with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error4").split_sequential(
ratios=[0.5, 0.5], fixed=2
).execute()
def test_shuffle_no_seed(mem_db):
"""Test shuffling without a seed."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
# Create a permutation with shuffling (no seed)
permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# Row IDs should not be in sequential order due to shuffling
# This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
# in order
assert row_ids != list(range(100))
def test_shuffle_with_seed(mem_db):
"""Test that shuffling with a seed is reproducible."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create two identical permutations with same shuffle seed
perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
# Results should be identical due to same seed
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
assert data1["row_id"] == data2["row_id"]
assert data1["split_id"] == data2["split_id"]
def test_shuffle_with_clump_size(mem_db):
"""Test shuffling with clump size."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
# Create a permutation with shuffling using clumps
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.shuffle(clump_size=10) # 10-row clumps
.execute()
)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
for i in range(10):
start = row_ids[i * 10]
assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))
def test_shuffle_different_seeds(mem_db):
"""Test that different seeds produce different shuffle orders."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create two permutations with different shuffle seeds
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(fixed=2)
.shuffle(seed=42)
.execute()
)
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(fixed=2)
.shuffle(seed=123)
.execute()
)
# Results should be different due to different seeds
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
# Row order should be different
assert data1["row_id"] != data2["row_id"]
def test_shuffle_combined_with_splits(mem_db):
"""Test shuffling combined with different split strategies."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100],
"value": range(100),
}
),
)
# Test shuffle with random splits
perm_random = (
permutation_builder(tbl, "perm_random")
.split_random(ratios=[0.6, 0.4], seed=42)
.shuffle(seed=123, clump_size=None)
.execute()
)
# Test shuffle with hash splits
perm_hash = (
permutation_builder(tbl, "perm_hash")
.split_hash(["category"], [1, 1], discard_weight=0)
.shuffle(seed=456, clump_size=5)
.execute()
)
# Test shuffle with sequential splits
perm_sequential = (
permutation_builder(tbl, "perm_sequential")
.split_sequential(counts=[40, 35])
.shuffle(seed=789, clump_size=None)
.execute()
)
# Verify all permutations work and have expected properties
assert perm_random.count_rows() == 100
assert perm_hash.count_rows() == 100
assert perm_sequential.count_rows() == 75
# Verify shuffle affected the order
data_random = perm_random.search(None).to_arrow().to_pydict()
data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
assert data_random["row_id"] != list(range(100))
assert data_sequential["row_id"] != list(range(75))
def test_no_shuffle_maintains_order(mem_db):
"""Test that not calling shuffle maintains the original order."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create permutation without shuffle (should maintain some order)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_sequential(counts=[25, 25]) # Sequential maintains order
.execute()
)
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# With sequential splits and no shuffle, should maintain order
assert row_ids == list(range(50))
def test_filter_basic(mem_db):
"""Test basic filtering functionality."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100, 200)})
)
# Filter to only include rows where id < 50
permutation_tbl = (
permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
)
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# All row_ids should be less than 50
assert all(row_id < 50 for row_id in row_ids)
def test_filter_with_splits(mem_db):
"""Test filtering combined with split strategies."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100],
"value": range(100),
}
),
)
# Filter to only category A and B, then split
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.filter("category IN ('A', 'B')")
.split_random(ratios=[0.5, 0.5])
.execute()
)
# Should have fewer than 100 rows due to filtering
row_count = permutation_tbl.count_rows()
assert row_count == 67
data = permutation_tbl.search(None).to_arrow().to_pydict()
categories = data["category"]
# All categories should be A or B
assert all(cat in ["A", "B"] for cat in categories)
def test_filter_with_shuffle(mem_db):
"""Test filtering combined with shuffling."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C", "D"] * 25)[:100],
"value": range(100),
}
),
)
# Filter and shuffle
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.filter("category IN ('A', 'C')")
.shuffle(seed=42)
.execute()
)
row_count = permutation_tbl.count_rows()
assert row_count == 50 # Should have 50 rows (A and C categories)
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
assert row_ids != sorted(row_ids)
def test_filter_empty_result(mem_db):
"""Test filtering that results in empty set."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(10), "value": range(10)})
)
# Filter that matches nothing
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.filter("value > 100") # No values > 100 in our data
.execute()
)
assert permutation_tbl.count_rows() == 0

View File

@@ -4,10 +4,7 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::{
connection::Connection as LanceConnection,
database::{CreateTableMode, ReadConsistency},
};
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -26,7 +23,7 @@ impl Connection {
Self { inner: Some(inner) }
}
pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
fn get_inner(&self) -> PyResult<&LanceConnection> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -66,18 +63,6 @@ impl Connection {
self.get_inner().map(|inner| inner.uri().to_string())
}
#[pyo3(signature = ())]
pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
Ok(match inner.read_consistency().await.infer_error()? {
ReadConsistency::Manual => None,
ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
ReadConsistency::Strong => Some(0.0_f64),
})
})
}
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
pub fn table_names(
self_: PyRef<'_, Self>,

View File

@@ -5,7 +5,6 @@ use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
use permutation::PyAsyncPermutationBuilder;
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -23,7 +22,6 @@ pub mod connection;
pub mod error;
pub mod header;
pub mod index;
pub mod permutation;
pub mod query;
pub mod session;
pub mod table;
@@ -51,9 +49,7 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DeleteResult>()?;
m.add_class::<DropColumnsResult>()?;
m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

View File

@@ -1,177 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use pyo3::{
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
PyResult,
};
use pyo3_async_runtimes::tokio::future_into_py;
/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(
table: Bound<'_, PyAny>,
dest_table_name: String,
) -> PyResult<PyAsyncPermutationBuilder> {
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
let inner_table = table.borrow().inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PyAsyncPermutationBuilder {
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
builder: Some(inner_builder),
dest_table_name,
})),
})
}
struct PyAsyncPermutationBuilderState {
builder: Option<LancePermutationBuilder>,
dest_table_name: String,
}
#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}
impl PyAsyncPermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> PyResult<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[pymethods]
impl PyAsyncPermutationBuilder {
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
pub fn split_random(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
seed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
pub fn split_hash(
slf: PyRefMut<'_, Self>,
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns,
split_weights,
discard_weight,
})
})
}
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
pub fn split_sequential(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
}
pub fn shuffle(
slf: PyRefMut<'_, Self>,
seed: Option<u64>,
clump_size: Option<u64>,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_filter(filter))
}
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let mut state = slf.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
future_into_py(slf.py(), async move {
let table = builder.build(&dest_table_name).await.infer_error()?;
Ok(Table::new(table))
})
}
}

View File

@@ -3,7 +3,6 @@
use std::{collections::HashMap, sync::Arc};
use crate::{
connection::Connection,
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
@@ -250,7 +249,7 @@ impl Table {
}
impl Table {
pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
fn inner_ref(&self) -> PyResult<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -273,13 +272,6 @@ impl Table {
self.inner.take();
}
pub fn database(&self) -> PyResult<Connection> {
let inner = self.inner_ref()?.clone();
let inner_connection =
lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
Ok(Connection::new(inner_connection))
}
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {

View File

@@ -1,2 +1,2 @@
[toolchain]
channel = "1.90.0"
channel = "1.86.0"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.22.2"
version = "0.22.2-beta.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -11,7 +11,6 @@ rust-version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-data = { workspace = true }
@@ -25,16 +24,12 @@ datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-plan.workspace = true
datafusion.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
lance-core = { workspace = true }
lance-datafusion.workspace = true
lance-datagen = { workspace = true }
lance-file = { workspace = true }
lance-io = { workspace = true }
lance-index = { workspace = true }
lance-table = { workspace = true }
@@ -51,13 +46,11 @@ bytes = "1"
futures.workspace = true
num-traits.workspace = true
url.workspace = true
rand.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
tempfile = "3.5.0"
aws-sdk-bedrockruntime = { version = "1.27.0", optional = true }
# For remote feature
reqwest = { version = "0.12.0", default-features = false, features = [
@@ -68,8 +61,9 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"macos-system-configuration",
"stream",
], optional = true }
rand = { version = "0.9", features = ["small_rng"], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
uuid = { version = "1.7.0", features = ["v4"] }
uuid = { version = "1.7.0", features = ["v4"], optional = true }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [
@@ -90,6 +84,7 @@ bytemuck_derive.workspace = true
[dev-dependencies]
anyhow = "1"
tempfile = "3.5.0"
rand = { version = "0.9", features = ["small_rng"] }
random_word = { version = "0.4.3", features = ["en"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
@@ -101,7 +96,6 @@ aws-smithy-runtime = { version = "1.9.1" }
datafusion.workspace = true
http-body = "1" # Matching reqwest
rstest = "0.23.0"
test-log = "0.2"
[features]
@@ -111,7 +105,7 @@ oss = ["lance/oss", "lance-io/oss"]
gcs = ["lance/gcp", "lance-io/gcp"]
azure = ["lance/azure", "lance-io/azure"]
dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http"]
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -7,7 +7,6 @@ pub use arrow_schema;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{Stream, StreamExt, TryStreamExt};
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
#[cfg(feature = "polars")]
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
@@ -162,26 +161,6 @@ impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
}
}
pub trait LanceDbDatagenExt {
fn into_ldb_stream(
self,
batch_size: RowCount,
num_batches: BatchCount,
) -> SendableRecordBatchStream;
}
impl LanceDbDatagenExt for BatchGeneratorBuilder {
fn into_ldb_stream(
self,
batch_size: RowCount,
num_batches: BatchCount,
) -> SendableRecordBatchStream {
let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
let stream = stream.map_err(|err| Error::Arrow { source: err });
Box::pin(SimpleRecordBatchStream::new(stream, schema))
}
}
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {

View File

@@ -19,7 +19,7 @@ use crate::database::listing::{
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
OpenTableRequest, ReadConsistency, TableNamesRequest,
OpenTableRequest, TableNamesRequest,
};
use crate::embeddings::{
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
@@ -152,7 +152,6 @@ impl CreateTableBuilder<true> {
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
parent,
embedding_registry,
))
}
@@ -212,9 +211,9 @@ impl CreateTableBuilder<false> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
let parent = self.parent.clone();
let table = parent.create_table(self.request).await?;
Ok(Table::new(table, parent))
Ok(Table::new(
self.parent.clone().create_table(self.request).await?,
))
}
}
@@ -463,10 +462,8 @@ impl OpenTableBuilder {
/// Open the table
pub async fn execute(self) -> Result<Table> {
let table = self.parent.open_table(self.request).await?;
Ok(Table::new_with_embedding_registry(
table,
self.parent,
self.parent.clone().open_table(self.request).await?,
self.embedding_registry,
))
}
@@ -522,15 +519,16 @@ impl CloneTableBuilder {
/// Execute the clone operation
pub async fn execute(self) -> Result<Table> {
let parent = self.parent.clone();
let table = parent.clone_table(self.request).await?;
Ok(Table::new(table, parent))
Ok(Table::new(
self.parent.clone().clone_table(self.request).await?,
))
}
}
/// A connection to LanceDB
#[derive(Clone)]
pub struct Connection {
uri: String,
internal: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -542,19 +540,9 @@ impl std::fmt::Display for Connection {
}
impl Connection {
pub fn new(
internal: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
internal,
embedding_registry,
}
}
/// Get the URI of the connection
pub fn uri(&self) -> &str {
self.internal.uri()
self.uri.as_str()
}
/// Get access to the underlying database
@@ -687,11 +675,6 @@ impl Connection {
.await
}
/// Get the read consistency of the connection
pub async fn read_consistency(&self) -> Result<ReadConsistency> {
self.internal.read_consistency().await
}
/// Drop a table in the database.
///
/// # Arguments
@@ -990,6 +973,7 @@ impl ConnectBuilder {
)?);
Ok(Connection {
internal,
uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1012,6 +996,7 @@ impl ConnectBuilder {
let internal = Arc::new(ListingDatabase::connect_with_options(&self.request).await?);
Ok(Connection {
internal,
uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1119,6 +1104,7 @@ impl ConnectNamespaceBuilder {
Ok(Connection {
internal,
uri: format!("namespace://{}", self.ns_impl),
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1153,6 +1139,7 @@ mod test_utils {
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
internal,
uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1169,6 +1156,7 @@ mod test_utils {
));
Self {
internal,
uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1199,7 +1187,7 @@ mod tests {
#[tokio::test]
async fn test_connect() {
let tc = new_test_connection().await.unwrap();
assert_eq!(tc.connection.uri(), tc.uri);
assert_eq!(tc.connection.uri, tc.uri);
}
#[cfg(not(windows))]
@@ -1220,7 +1208,7 @@ mod tests {
.await
.unwrap();
assert_eq!(db.uri(), relative_uri.to_str().unwrap().to_string());
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
}
#[tokio::test]

View File

@@ -52,13 +52,13 @@ pub fn infer_vector_columns(
for field in reader.schema().fields() {
match field.data_type() {
DataType::FixedSizeList(sub_field, _) if sub_field.data_type().is_floating() => {
columns.push(field.name().clone());
columns.push(field.name().to_string());
}
DataType::List(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().clone(), None);
columns_to_infer.insert(field.name().to_string(), None);
}
DataType::LargeList(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().clone(), None);
columns_to_infer.insert(field.name().to_string(), None);
}
_ => {}
}

View File

@@ -16,7 +16,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
@@ -214,20 +213,6 @@ impl CloneTableRequest {
}
}
/// How long until a change is reflected from one Table instance to another
///
/// Tables are always internally consistent. If a write method is called on a table instance,
/// the change is immediately visible in that same table instance.
pub enum ReadConsistency {
/// Changes will not be automatically propagated until the checkout_latest
/// method is called on the target table
Manual,
/// Changes will be propagated automatically within the given duration
Eventual(Duration),
/// Changes are immediately visible in target tables
Strong,
}
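// Note: the Python binding surfaces this as an optional interval in seconds
// (Manual -> None, Strong -> 0.0, Eventual(d) -> d.as_secs_f64()); see
// get_read_consistency_interval in the Python Connection bindings earlier in this diff.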
/// The `Database` trait defines the interface for database implementations.
///
/// A database is responsible for managing tables and their metadata.
@@ -235,10 +220,6 @@ pub enum ReadConsistency {
pub trait Database:
Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
{
/// Get the uri of the database
fn uri(&self) -> &str;
/// Get the read consistency of the database
async fn read_consistency(&self) -> Result<ReadConsistency>;
/// List immediate child namespace names in the given namespace
async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>>;
/// Create a new namespace

View File

@@ -17,7 +17,6 @@ use object_store::local::LocalFileSystem;
use snafu::ResultExt;
use crate::connection::ConnectRequest;
use crate::database::ReadConsistency;
use crate::error::{CreateDirSnafu, Error, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::NativeTable;
@@ -599,22 +598,6 @@ impl Database for ListingDatabase {
Ok(Vec::new())
}
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
if let Some(read_consistency_interval) = self.read_consistency_interval {
if read_consistency_interval.is_zero() {
Ok(ReadConsistency::Strong)
} else {
Ok(ReadConsistency::Eventual(read_consistency_interval))
}
} else {
Ok(ReadConsistency::Manual)
}
}
async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> {
Err(Error::NotSupported {
message: "Namespace operations are not supported for listing database".into(),
@@ -735,9 +718,9 @@ impl Database for ListingDatabase {
.map_err(|e| Error::Lance { source: e })?;
let version_ref = match (request.source_version, request.source_tag) {
(Some(v), None) => Ok(Ref::Version(None, Some(v))),
(Some(v), None) => Ok(Ref::Version(v)),
(None, Some(tag)) => Ok(Ref::Tag(tag)),
(None, None) => Ok(Ref::Version(None, Some(source_dataset.version().version))),
(None, None) => Ok(Ref::Version(source_dataset.version().version)),
_ => Err(Error::InvalidInput {
message: "Cannot specify both source_version and source_tag".to_string(),
}),
@@ -1266,8 +1249,7 @@ mod tests {
)
.unwrap();
let db = Arc::new(db);
let source_table_obj = Table::new(source_table.clone(), db.clone());
let source_table_obj = Table::new(source_table.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1338,8 +1320,7 @@ mod tests {
.unwrap();
// Create a tag for the current version
let db = Arc::new(db);
let source_table_obj = Table::new(source_table.clone(), db.clone());
let source_table_obj = Table::new(source_table.clone());
let mut tags = source_table_obj.tags().await.unwrap();
tags.create("v1.0", source_table.version().await.unwrap())
.await
@@ -1355,7 +1336,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone(), db.clone());
let source_table_obj = Table::new(source_table.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1451,8 +1432,7 @@ mod tests {
)
.unwrap();
let db = Arc::new(db);
let cloned_table_obj = Table::new(cloned_table.clone(), db.clone());
let cloned_table_obj = Table::new(cloned_table.clone());
cloned_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_clone)],
@@ -1472,7 +1452,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone(), db);
let source_table_obj = Table::new(source_table.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_source)],
@@ -1515,7 +1495,6 @@ mod tests {
.unwrap();
// Add more data to create new versions
let db = Arc::new(db);
for i in 0..3 {
let batch = RecordBatch::try_new(
schema.clone(),
@@ -1523,7 +1502,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone(), db.clone());
let source_table_obj = Table::new(source_table.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch)],

View File

@@ -16,9 +16,9 @@ use lance_namespace::{
LanceNamespace,
};
use crate::connection::ConnectRequest;
use crate::database::listing::ListingDatabase;
use crate::error::{Error, Result};
use crate::{connection::ConnectRequest, database::ReadConsistency};
use super::{
BaseTable, CloneTableRequest, CreateNamespaceRequest as DbCreateNamespaceRequest,
@@ -36,8 +36,6 @@ pub struct LanceNamespaceDatabase {
read_consistency_interval: Option<std::time::Duration>,
// Optional session for object stores and caching
session: Option<Arc<lance::session::Session>>,
// database URI
uri: String,
}
impl LanceNamespaceDatabase {
@@ -59,7 +57,6 @@ impl LanceNamespaceDatabase {
storage_options,
read_consistency_interval,
session,
uri: format!("namespace://{}", ns_impl),
})
}
@@ -133,22 +130,6 @@ impl std::fmt::Display for LanceNamespaceDatabase {
#[async_trait]
impl Database for LanceNamespaceDatabase {
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
if let Some(read_consistency_interval) = self.read_consistency_interval {
if read_consistency_interval.is_zero() {
Ok(ReadConsistency::Strong)
} else {
Ok(ReadConsistency::Eventual(read_consistency_interval))
}
} else {
Ok(ReadConsistency::Manual)
}
}
async fn list_namespaces(&self, request: DbListNamespacesRequest) -> Result<Vec<String>> {
let ns_request = ListNamespacesRequest {
id: if request.namespace.is_empty() {
@@ -280,7 +261,7 @@ impl Database for LanceNamespaceDatabase {
return listing_db
.open_table(OpenTableRequest {
name: request.name.clone(),
namespace: vec![],
namespace: request.namespace.clone(),
index_cache_size: None,
lance_read_params: None,
})
@@ -324,14 +305,7 @@ impl Database for LanceNamespaceDatabase {
)
.await?;
let create_request = DbCreateTableRequest {
name: request.name,
namespace: vec![],
data: request.data,
mode: request.mode,
write_options: request.write_options,
};
listing_db.create_table(create_request).await
listing_db.create_table(request).await
}
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
@@ -358,13 +332,7 @@ impl Database for LanceNamespaceDatabase {
.create_listing_database(&request.name, &location, response.storage_options)
.await?;
let open_request = OpenTableRequest {
name: request.name.clone(),
namespace: vec![],
index_cache_size: request.index_cache_size,
lance_read_params: request.lance_read_params,
};
listing_db.open_table(open_request).await
listing_db.open_table(request).await
}
async fn clone_table(&self, _request: CloneTableRequest) -> Result<Arc<dyn BaseTable>> {

View File

@@ -1,7 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod permutation;
pub mod shuffle;
pub mod split;
pub mod util;

View File

@@ -1,294 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Contains the [PermutationBuilder] to create a permutation "view" of an existing table.
//!
//! A permutation view can apply a filter, divide the data into splits, and shuffle the data.
//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
//! the underlying data and can be very lightweight.
//!
//! Building a permutation table should be fairly quick and memory efficient, even for billions or
//! trillions of rows.
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datafusion::exec::SessionContextExt;
use crate::{
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
dataloader::{
shuffle::{Shuffler, ShufflerConfig},
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase},
Connection, Error, Result, Table,
};
/// Configuration for creating a permutation table
#[derive(Debug, Default)]
pub struct PermutationConfig {
/// Splitting configuration
pub split_strategy: SplitStrategy,
/// Shuffle strategy
pub shuffle_strategy: ShuffleStrategy,
/// Optional filter to apply to the base table
pub filter: Option<String>,
/// Directory to use for temporary files
pub temp_dir: TemporaryDirectory,
}
/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
pub enum ShuffleStrategy {
/// The data is randomly shuffled
///
/// A seed can be provided to make the shuffle deterministic.
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
///
/// For example, a clump size of 16 means we shuffle blocks of 16 contiguous rows. This
/// gives roughly 16x fewer IOPS, but those 16 rows always stay close together, which can
/// influence model performance. Note: shuffling within clumps can still be done at read time,
/// but that only provides a local shuffle, not a global one.
Random {
seed: Option<u64>,
clump_size: Option<u64>,
},
/// The data is not shuffled
///
/// This is useful for debugging and testing.
None,
}
impl Default for ShuffleStrategy {
fn default() -> Self {
Self::None
}
}
/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
/// can be used to create a lightweight, shuffled "view" of an existing table without materializing
/// a copy of the underlying data.
pub struct PermutationBuilder {
config: PermutationConfig,
base_table: Table,
}
impl PermutationBuilder {
pub fn new(base_table: Table) -> Self {
Self {
config: PermutationConfig::default(),
base_table,
}
}
/// Configures the strategy for assigning rows to splits.
///
/// For example, it is common to create a test/train split of the data. Splits can also be used
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
/// create a single split with 10% of the data.
///
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
/// multiple processes and multiple nodes.
///
/// The default is a single split that contains all rows.
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
self.config.split_strategy = split_strategy;
self
}
/// Configures the strategy for shuffling the data.
///
/// The default is `ShuffleStrategy::None`: unless a shuffle strategy is configured, the data is
/// not shuffled.
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
self.config.shuffle_strategy = shuffle_strategy;
self
}
/// Configures a filter to apply to the base table.
///
/// Only rows matching the filter will be included in the permutation.
pub fn with_filter(mut self, filter: String) -> Self {
self.config.filter = Some(filter);
self
}
/// Configures the directory to use for temporary files.
///
/// The default is to use the operating system's default temporary directory.
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
self.config.temp_dir = temp_dir;
self
}
async fn sort_by_split_id(
&self,
data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
let ctx = SessionContext::new_with_config_rt(
SessionConfig::default(),
RuntimeEnvBuilder::new()
.with_memory_limit(100 * 1024 * 1024, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default()
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
)
.build_arc()
.unwrap(),
);
let df = ctx
.read_one_shot(data.into_df_stream())
.map_err(|e| Error::Other {
message: format!("Failed to setup sort by split id: {}", e),
source: Some(e.into()),
})?;
let df_stream = df
.sort_by(vec![col(SPLIT_ID_COLUMN)])
.map_err(|e| Error::Other {
message: format!("Failed to plan sort by split id: {}", e),
source: Some(e.into()),
})?
.execute_stream()
.await
.map_err(|e| Error::Other {
message: format!("Failed to sort by split id: {}", e),
source: Some(e.into()),
})?;
let schema = df_stream.schema();
let stream = df_stream.map_err(|e| Error::Other {
message: format!("Failed to execute sort by split id: {}", e),
source: Some(e.into()),
});
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
/// Builds the permutation table and stores it in the given database.
pub async fn build(self, dest_table_name: &str) -> Result<Table> {
// First pass, apply filter and load row ids
let mut rows = self.base_table.query().with_row_id();
if let Some(filter) = &self.config.filter {
rows = rows.only_if(filter);
}
let splitter = Splitter::new(
self.config.temp_dir.clone(),
self.config.split_strategy.clone(),
);
let mut needs_sort = !splitter.orders_by_split_id();
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
// split id)
rows = splitter.project(rows);
let num_rows = self
.base_table
.count_rows(self.config.filter.clone())
.await? as u64;
// Apply splits
let rows = rows.execute().await?;
let split_data = splitter.apply(rows, num_rows).await?;
// Shuffle data if requested
let shuffled = match self.config.shuffle_strategy {
ShuffleStrategy::None => split_data,
ShuffleStrategy::Random { seed, clump_size } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed,
clump_size,
temp_dir: self.config.temp_dir.clone(),
max_rows_per_file: 10 * 1024 * 1024,
});
shuffler.shuffle(split_data, num_rows).await?
}
};
// We want the final permutation to be sorted by the split id. If we shuffled or if
// the split was not assigned sequentially then we need to sort the data.
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
let sorted = if needs_sort {
self.sort_by_split_id(shuffled).await?
} else {
shuffled
};
// Rename _rowid to row_id
let renamed = rename_column(sorted, "_rowid", "row_id")?;
// Create permutation table
let conn = Connection::new(
self.base_table.database().clone(),
self.base_table.embedding_registry().clone(),
);
conn.create_table_streaming(dest_table_name, renamed)
.execute()
.await
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use lance_datagen::{BatchCount, RowCount};
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes};
use super::*;
#[tokio::test]
async fn test_permutation_builder() {
let temp_dir = tempfile::tempdir().unwrap();
let db = connect(temp_dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let initial_data = lance_datagen::gen_batch()
.col("some_value", lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
let data_table = db
.create_table_streaming("mytbl", initial_data)
.execute()
.await
.unwrap();
let permutation_table = PermutationBuilder::new(data_table)
.with_filter("some_value > 57".to_string())
.with_split_strategy(SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
})
.build("permutation")
.await
.unwrap();
// Potentially brittle seed-dependent values below
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 0".to_string()))
.await
.unwrap(),
47
);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 1".to_string()))
.await
.unwrap(),
283
);
}
}

View File

@@ -1,475 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use arrow::compute::concat_batches;
use arrow_array::{RecordBatch, UInt64Array};
use futures::{StreamExt, TryStreamExt};
use lance::io::ObjectStore;
use lance_core::{cache::LanceCache, utils::futures::FinallyStreamExt};
use lance_encoding::decoder::DecoderPlugins;
use lance_file::v2::{
reader::{FileReader, FileReaderOptions},
writer::{FileWriter, FileWriterOptions},
};
use lance_index::scalar::IndexReader;
use lance_io::{
scheduler::{ScanScheduler, SchedulerConfig},
utils::CachedFileSize,
};
use rand::{seq::SliceRandom, Rng, RngCore};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::util::{non_crypto_rng, TemporaryDirectory},
Error, Result,
};
#[derive(Debug, Clone)]
pub struct ShufflerConfig {
/// An optional seed to make the shuffle deterministic
pub seed: Option<u64>,
/// The maximum number of rows to write to a single file
///
/// The shuffler will need to hold at least this many rows in memory. Setting this value
/// extremely large could cause the shuffler to use a lot of memory (depending on row size).
///
/// However, the shuffler will also need to hold total_num_rows / max_rows_per_file file
/// writers in memory. Each of these will consume some amount of memory for column write buffers.
/// So setting this value too small could _also_ cause the shuffler to use a lot of memory and
/// open file handles.
pub max_rows_per_file: u64,
/// The temporary directory to use for writing files
pub temp_dir: TemporaryDirectory,
/// The size of the clumps to shuffle within
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
pub clump_size: Option<u64>,
}
impl Default for ShufflerConfig {
fn default() -> Self {
Self {
max_rows_per_file: 1024 * 1024,
seed: Option::default(),
temp_dir: TemporaryDirectory::default(),
clump_size: None,
}
}
}
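// Hedged sketch (not part of the original file): one way a caller might fill in the
// config above. The scratch path and values here are illustrative assumptions only.
#[allow(dead_code)]
fn example_shuffler_config() -> ShufflerConfig {
ShufflerConfig {
// Seed the RNG so repeated runs produce the same permutation.
seed: Some(42),
// Each temporary file holds at most this many rows; the final per-file in-memory
// shuffle therefore has to hold roughly this many rows at once.
max_rows_per_file: 1024 * 1024,
// Spill to a specific scratch directory instead of the OS default temp dir.
temp_dir: TemporaryDirectory::Specific("/mnt/scratch".into()),
// Shuffle 64-row clumps, trading some randomness for more sequential reads.
clump_size: Some(64),
}
}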
/// A shuffler that can shuffle a stream of record batches
///
/// To do this the stream is consumed and written to temporary files. A new stream is returned
/// which returns the shuffled data from the temporary files.
///
/// If there are fewer than max_rows_per_file rows in the input stream, then the shuffler will not
/// write any files and will instead perform an in-memory shuffle.
///
/// The number of rows in the input stream must be known in advance.
pub struct Shuffler {
config: ShufflerConfig,
id: String,
}
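// Hedged usage sketch (not in the original file): shuffling a stream whose total row
// count is known in advance, as the shuffler requires. `input` and `total_rows` are
// assumed to be provided by the caller.
#[allow(dead_code)]
async fn example_shuffle(
input: SendableRecordBatchStream,
total_rows: u64,
) -> Result<SendableRecordBatchStream> {
let shuffler = Shuffler::new(ShufflerConfig {
seed: Some(7),
..Default::default()
});
// Fewer than max_rows_per_file rows are shuffled entirely in memory; otherwise the
// stream is spilled to temporary Lance files and read back in shuffled order.
shuffler.shuffle(input, total_rows).await
}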
impl Shuffler {
pub fn new(config: ShufflerConfig) -> Self {
let id = uuid::Uuid::new_v4().to_string();
Self { config, id }
}
/// Shuffles a single batch of data in memory
fn shuffle_batch(
batch: &RecordBatch,
rng: &mut dyn RngCore,
clump_size: u64,
) -> Result<RecordBatch> {
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
let mut indices = (0..num_clumps).collect::<Vec<_>>();
indices.shuffle(rng);
let indices = if clump_size == 1 {
UInt64Array::from(indices)
} else {
UInt64Array::from_iter_values(indices.iter().flat_map(|&clump_index| {
if clump_index == num_clumps - 1 {
clump_index * clump_size..batch.num_rows() as u64
} else {
clump_index * clump_size..(clump_index + 1) * clump_size
}
}))
};
Ok(arrow::compute::take_record_batch(batch, &indices)?)
}
async fn in_memory_shuffle(
&self,
data: SendableRecordBatchStream,
mut rng: Box<dyn RngCore + Send>,
) -> Result<SendableRecordBatchStream> {
let schema = data.schema();
let batches = data.try_collect::<Vec<_>>().await?;
let batch = concat_batches(&schema, &batches)?;
let shuffled = Self::shuffle_batch(&batch, &mut rng, self.config.clump_size.unwrap_or(1))?;
log::debug!("Shuffle job {}: in-memory shuffle complete", self.id);
Ok(Box::pin(SimpleRecordBatchStream::new(
futures::stream::once(async move { Ok(shuffled) }),
schema,
)))
}
async fn do_shuffle(
&self,
mut data: SendableRecordBatchStream,
num_rows: u64,
mut rng: Box<dyn RngCore + Send>,
) -> Result<SendableRecordBatchStream> {
let num_files = num_rows.div_ceil(self.config.max_rows_per_file);
let temp_dir = self.config.temp_dir.create_temp_dir()?;
let tmp_dir = temp_dir.path().to_path_buf();
let clump_size = self.config.clump_size.unwrap_or(1);
if clump_size == 0 {
return Err(Error::InvalidInput {
message: "clump size must be greater than 0".to_string(),
});
}
let object_store = ObjectStore::local();
let arrow_schema = data.schema();
let schema = lance::datatypes::Schema::try_from(arrow_schema.as_ref())?;
// Create file writers
let mut file_writers = Vec::with_capacity(num_files as usize);
for file_index in 0..num_files {
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", self.id));
let path =
object_store::path::Path::from_absolute_path(path).map_err(|err| Error::Other {
message: format!("Failed to create temporary file: {}", err),
source: None,
})?;
let object_writer = object_store.create(&path).await?;
let writer =
FileWriter::try_new(object_writer, schema.clone(), FileWriterOptions::default())?;
file_writers.push(writer);
}
let mut num_rows_seen = 0;
// Randomly distribute clumps to files
while let Some(batch) = data.try_next().await? {
num_rows_seen += batch.num_rows() as u64;
let is_last = num_rows_seen == num_rows;
if num_rows_seen > num_rows {
return Err(Error::Runtime {
message: format!("Expected {} rows but saw {} rows", num_rows, num_rows_seen),
});
}
// This is an annoying limitation, but if we allowed runt clumps in the middle of the
// stream then clump boundaries would become unaligned and the in-memory shuffle step
// would break clumps apart. If this becomes a problem we can find a better approach.
if !is_last && batch.num_rows() as u64 % clump_size != 0 {
return Err(Error::Runtime {
message: format!(
"Expected batch size ({}) to be divisible by clump size ({})",
batch.num_rows(),
clump_size
),
});
}
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
let mut batch_offsets_for_files =
vec![Vec::<u64>::with_capacity(batch.num_rows()); num_files as usize];
// Partition the batch randomly and write to the appropriate accumulator
for clump_offset in 0..num_clumps {
let clump_start = clump_offset * clump_size;
let num_rows_in_clump = clump_size.min(batch.num_rows() as u64 - clump_start);
let clump_end = clump_start + num_rows_in_clump;
let file_index = rng.random_range(0..num_files);
batch_offsets_for_files[file_index as usize].extend(clump_start..clump_end);
}
for (file_index, batch_offsets) in batch_offsets_for_files.into_iter().enumerate() {
if batch_offsets.is_empty() {
continue;
}
let indices = UInt64Array::from(batch_offsets);
let partition = arrow::compute::take_record_batch(&batch, &indices)?;
file_writers[file_index].write_batch(&partition).await?;
}
}
// Finish writing files
for (file_idx, mut writer) in file_writers.into_iter().enumerate() {
let num_written = writer.finish().await?;
log::debug!(
"Shuffle job {}: wrote {} rows to file {}",
self.id,
num_written,
file_idx
);
}
let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
let scan_scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config);
let job_id = self.id.clone();
let rng = Arc::new(Mutex::new(rng));
// Second pass, read each file as a single batch and shuffle
let stream = futures::stream::iter(0..num_files)
.then(move |file_index| {
let scan_scheduler = scan_scheduler.clone();
let rng = rng.clone();
let tmp_dir = tmp_dir.clone();
let job_id = job_id.clone();
async move {
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", job_id));
let path = object_store::path::Path::from_absolute_path(path).unwrap();
let file_scheduler = scan_scheduler
.open_file(&path, &CachedFileSize::unknown())
.await?;
let reader = FileReader::try_open(
file_scheduler,
None,
Arc::<DecoderPlugins>::default(),
&LanceCache::no_cache(),
FileReaderOptions::default(),
)
.await?;
// Need to read the entire file in a single batch for in-memory shuffling
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
let mut rng = rng.lock().unwrap();
Self::shuffle_batch(&batch, &mut rng, clump_size)
}
})
.finally(move || drop(temp_dir))
.boxed();
Ok(Box::pin(SimpleRecordBatchStream::new(stream, arrow_schema)))
}
pub async fn shuffle(
self,
data: SendableRecordBatchStream,
num_rows: u64,
) -> Result<SendableRecordBatchStream> {
log::debug!(
"Shuffle job {}: shuffling {} rows and {} columns",
self.id,
num_rows,
data.schema().fields.len()
);
let rng = non_crypto_rng(&self.config.seed);
if num_rows < self.config.max_rows_per_file {
return self.in_memory_shuffle(data, rng).await;
}
self.do_shuffle(data, num_rows, rng).await
}
}
#[cfg(test)]
mod tests {
use crate::arrow::LanceDbDatagenExt;
use super::*;
use arrow::{array::AsArray, datatypes::Int32Type};
use datafusion::prelude::SessionContext;
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed};
use rand::{rngs::SmallRng, SeedableRng};
fn test_gen() -> BatchGeneratorBuilder {
lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col("id", lance_datagen::array::step::<Int32Type>())
.col(
"name",
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
)
}
fn create_test_batch(size: RowCount) -> RecordBatch {
test_gen().into_batch_rows(size).unwrap()
}
fn create_test_stream(
num_batches: BatchCount,
batch_size: RowCount,
) -> SendableRecordBatchStream {
test_gen().into_ldb_stream(batch_size, num_batches)
}
#[test]
fn test_shuffle_batch_deterministic() {
let batch = create_test_batch(RowCount::from(10));
let mut rng1 = SmallRng::seed_from_u64(42);
let mut rng2 = SmallRng::seed_from_u64(42);
let shuffled1 = Shuffler::shuffle_batch(&batch, &mut rng1, 1).unwrap();
let shuffled2 = Shuffler::shuffle_batch(&batch, &mut rng2, 1).unwrap();
// Same seed should produce same shuffle
assert_eq!(shuffled1, shuffled2);
}
#[test]
fn test_shuffle_with_clumps() {
let batch = create_test_batch(RowCount::from(10));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 3).unwrap();
let values = shuffled.column(0).as_primitive::<Int32Type>();
let mut iter = values.into_iter().map(|o| o.unwrap());
let mut frag_seen = false;
let mut clumps_seen = 0;
while let Some(first) = iter.next() {
// 9 is the last value and not a full clump
if first != 9 {
// Otherwise we should have a full clump
let second = iter.next().unwrap();
let third = iter.next().unwrap();
assert_eq!(first + 1, second);
assert_eq!(first + 2, third);
clumps_seen += 1;
} else {
frag_seen = true;
}
}
assert_eq!(clumps_seen, 3);
assert!(frag_seen);
}
async fn sort_batch(batch: RecordBatch) -> RecordBatch {
let ctx = SessionContext::new();
let df = ctx.read_batch(batch).unwrap();
let sorted = df.sort_by(vec![col("id")]).unwrap();
let batches = sorted.collect().await.unwrap();
let schema = batches[0].schema();
concat_batches(&schema, &batches).unwrap()
}
#[tokio::test]
async fn test_shuffle_batch_preserves_data() {
let batch = create_test_batch(RowCount::from(100));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
assert_ne!(shuffled, batch);
let sorted = sort_batch(shuffled).await;
assert_eq!(sorted, batch);
}
#[test]
fn test_shuffle_batch_empty() {
let batch = create_test_batch(RowCount::from(0));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
assert_eq!(shuffled.num_rows(), 0);
}
#[tokio::test]
async fn test_in_memory_shuffle() {
let config = ShufflerConfig {
temp_dir: TemporaryDirectory::None,
..Default::default()
};
let shuffler = Shuffler::new(config);
let stream = create_test_stream(BatchCount::from(5), RowCount::from(20));
let result_stream = shuffler.shuffle(stream, 100).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
assert_eq!(result_batches.len(), 1);
let result_batch = result_batches.into_iter().next().unwrap();
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(20))
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = unshuffled_batches[0].schema();
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
let sorted = sort_batch(result_batch).await;
assert_eq!(unshuffled_batch, sorted);
}
#[tokio::test]
async fn test_external_shuffle() {
let config = ShufflerConfig {
max_rows_per_file: 100,
..Default::default()
};
let shuffler = Shuffler::new(config);
let stream = create_test_stream(BatchCount::from(5), RowCount::from(1000));
let result_stream = shuffler.shuffle(stream, 5000).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(1000))
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = unshuffled_batches[0].schema();
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
assert_eq!(result_batches.len(), 50);
let result_batch = concat_batches(&schema, &result_batches).unwrap();
let sorted = sort_batch(result_batch).await;
assert_eq!(unshuffled_batch, sorted);
}
#[test_log::test(tokio::test)]
async fn test_external_clump_shuffle() {
let config = ShufflerConfig {
max_rows_per_file: 100,
clump_size: Some(30),
..Default::default()
};
let shuffler = Shuffler::new(config);
// Batch size (900) must be multiple of clump size (30)
let stream = create_test_stream(BatchCount::from(5), RowCount::from(900));
let schema = stream.schema();
// Remove 10 rows from the last batch to simulate ending with partial clump
let mut batches = stream.try_collect::<Vec<_>>().await.unwrap();
let last_index = batches.len() - 1;
let sliced_last = batches[last_index].slice(0, 890);
batches[last_index] = sliced_last;
let stream = Box::pin(SimpleRecordBatchStream::new(
futures::stream::iter(batches).map(Ok).boxed(),
schema.clone(),
));
let result_stream = shuffler.shuffle(stream, 4490).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
let result_batch = concat_batches(&schema, &result_batches).unwrap();
let ids = result_batch.column(0).as_primitive::<Int32Type>();
let mut iter = ids.into_iter().map(|o| o.unwrap());
while let Some(first) = iter.next() {
let rows_left_in_clump = if first == 4470 { 19 } else { 29 };
let mut expected_next = first + 1;
for _ in 0..rows_left_in_clump {
let next = iter.next().unwrap();
assert_eq!(next, expected_next);
expected_next += 1;
}
}
}
}

View File

@@ -1,804 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
iter,
sync::{
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
Arc,
},
};
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance::arrow::SchemaExt;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
shuffle::{Shuffler, ShufflerConfig},
util::TemporaryDirectory,
},
query::{Query, QueryBase, Select},
Error, Result,
};
pub const SPLIT_ID_COLUMN: &str = "split_id";
/// Strategy for assigning rows to splits
#[derive(Debug, Clone)]
pub enum SplitStrategy {
/// All rows will have split id 0
NoSplit,
/// Rows will be randomly assigned to splits
///
/// A seed can be provided to make the assignment deterministic.
Random {
seed: Option<u64>,
sizes: SplitSizes,
},
/// Rows will be assigned to splits based on the values in the specified columns.
///
/// This will ensure rows are always assigned to the same split if the given columns do not change.
///
/// The `split_weights` are used to determine the approximate number of rows in each split. This
/// controls how we divide up the u64 hash space. However, it does not guarantee any particular division
/// of rows. For example, if all rows have identical hash values then all rows will be assigned to the same split
/// regardless of the weights.
///
/// The `discard_weight` controls what percentage of rows should be thrown away. For example, if you want your
/// first split to have ~5% of your rows and the second split to have ~10% of your rows then you would set
/// `split_weights` to [1, 2] and `discard_weight` to 17 (or you could set `split_weights` to [5, 10] and `discard_weight`
/// to 85). If you set `discard_weight` to 0 then all rows will be assigned to a split.
Hash {
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
},
/// Rows will be assigned to splits sequentially.
///
/// The first N1 rows are assigned to split 1, the next N2 rows are assigned to split 2, etc.
///
/// This is mainly useful for debugging and testing.
Sequential { sizes: SplitSizes },
/// Rows will be assigned to splits based on a calculation of one or more columns.
///
/// This is useful when the splits already exist in the base table.
///
/// The provided `calculation` should be an SQL statement that returns an integer value between
/// 0 and the number of splits - 1 (the number of splits is defined by the `splits` configuration).
///
/// If this strategy is used then the counts/percentages in the SplitSizes are ignored.
Calculated { calculation: String },
}
// The default is not to split the data
//
// All data will be assigned to a single split.
impl Default for SplitStrategy {
fn default() -> Self {
Self::NoSplit
}
}
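// Illustrative sketch (not part of the original file): the hash-based strategy described
// in the doc comment above, targeting roughly 5% / 10% splits with the remaining ~85% of
// rows discarded. The column name is a hypothetical example.
#[allow(dead_code)]
fn example_hash_strategy() -> SplitStrategy {
SplitStrategy::Hash {
// Rows hash on these columns, so a row keeps its split as long as the values don't change.
columns: vec!["user_id".to_string()],
// Weights 1 and 2 out of a total of 20 (1 + 2 + 17) give ~5% and ~10% splits.
split_weights: vec![1, 2],
discard_weight: 17,
}
}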
impl SplitStrategy {
pub fn validate(&self, num_rows: u64) -> Result<()> {
match self {
Self::NoSplit => Ok(()),
Self::Random { sizes, .. } => sizes.validate(num_rows),
Self::Hash {
split_weights,
columns,
..
} => {
if columns.is_empty() {
return Err(Error::InvalidInput {
message: "Hash strategy requires at least one column".to_string(),
});
}
if split_weights.is_empty() {
return Err(Error::InvalidInput {
message: "Hash strategy requires at least one split weight".to_string(),
});
}
if split_weights.contains(&0) {
return Err(Error::InvalidInput {
message: "Split weights must be greater than 0".to_string(),
});
}
Ok(())
}
Self::Sequential { sizes } => sizes.validate(num_rows),
Self::Calculated { .. } => Ok(()),
}
}
}
pub struct Splitter {
temp_dir: TemporaryDirectory,
strategy: SplitStrategy,
}
impl Splitter {
pub fn new(temp_dir: TemporaryDirectory, strategy: SplitStrategy) -> Self {
Self { temp_dir, strategy }
}
fn sequential_split_id(
num_rows: u64,
split_sizes: &[u64],
split_index: &AtomicUsize,
counter_in_split: &AtomicU64,
exhausted: &AtomicBool,
) -> UInt64Array {
let mut split_ids = Vec::<u64>::with_capacity(num_rows as usize);
while split_ids.len() < num_rows as usize {
let split_id = split_index.load(Ordering::Relaxed);
let counter = counter_in_split.load(Ordering::Relaxed);
let split_size = split_sizes[split_id];
let remaining_in_split = split_size - counter;
let remaining_in_batch = num_rows - split_ids.len() as u64;
let mut done = false;
let rows_to_add = if remaining_in_batch < remaining_in_split {
counter_in_split.fetch_add(remaining_in_batch, Ordering::Relaxed);
remaining_in_batch
} else {
split_index.fetch_add(1, Ordering::Relaxed);
counter_in_split.store(0, Ordering::Relaxed);
if split_id == split_sizes.len() - 1 {
exhausted.store(true, Ordering::Relaxed);
done = true;
}
remaining_in_split
};
split_ids.extend(iter::repeat(split_id as u64).take(rows_to_add as usize));
if done {
// Quit early if we've run out of splits
break;
}
}
UInt64Array::from(split_ids)
}
async fn apply_sequential(
&self,
source: SendableRecordBatchStream,
num_rows: u64,
sizes: &SplitSizes,
) -> Result<SendableRecordBatchStream> {
let split_sizes = sizes.to_counts(num_rows);
let split_index = AtomicUsize::new(0);
let counter_in_split = AtomicU64::new(0);
let exhausted = AtomicBool::new(false);
let schema = source.schema();
let new_schema = Arc::new(schema.try_with_column(Field::new(
SPLIT_ID_COLUMN,
DataType::UInt64,
false,
))?);
let new_schema_clone = new_schema.clone();
let stream = source.filter_map(move |batch| {
let batch = match batch {
Ok(batch) => batch,
Err(e) => {
return std::future::ready(Some(Err(e)));
}
};
if exhausted.load(Ordering::Relaxed) {
return std::future::ready(None);
}
let split_ids = Self::sequential_split_id(
batch.num_rows() as u64,
&split_sizes,
&split_index,
&counter_in_split,
&exhausted,
);
let mut arrays = batch.columns().to_vec();
// This can happen if we exhaust all splits in the middle of a batch
if split_ids.len() < batch.num_rows() {
arrays = arrays
.iter()
.map(|arr| arr.slice(0, split_ids.len()))
.collect();
}
arrays.push(Arc::new(split_ids));
std::future::ready(Some(Ok(
RecordBatch::try_new(new_schema.clone(), arrays).unwrap()
)))
});
Ok(Box::pin(SimpleRecordBatchStream::new(
stream,
new_schema_clone,
)))
}
fn hash_split_id(batch: &RecordBatch, thresholds: &[u64], total_weight: u64) -> UInt64Array {
let arrays = batch
.columns()
.iter()
// Don't hash the last column which should always be the row id
.take(batch.columns().len() - 1)
.cloned()
.collect::<Vec<_>>();
let mut hashes = vec![0; batch.num_rows()];
let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
create_hashes(&arrays, &random_state, &mut hashes).unwrap();
// As an example, let's assume the weights are 1, 2. Our total weight is 3.
//
// Our thresholds are [1, 3]
// Our modulo output will be 0, 1, or 2.
//
// thresholds.binary_search(0) => Err(0) => 0
// thresholds.binary_search(1) => Ok(0) => 1
// thresholds.binary_search(2) => Err(1) => 1
let split_ids = hashes
.iter()
.map(|h| {
let h = h % total_weight;
let split_id = match thresholds.binary_search(&h) {
Ok(i) => (i + 1) as u64,
Err(i) => i as u64,
};
if split_id == thresholds.len() as u64 {
// If the hash falls past the last threshold (into the discard region) then we drop
// the row (indicated by setting the split_id to null)
None
} else {
Some(split_id)
}
})
.collect::<Vec<_>>();
UInt64Array::from(split_ids)
}
async fn apply_hash(
&self,
source: SendableRecordBatchStream,
weights: &[u64],
discard_weight: u64,
) -> Result<SendableRecordBatchStream> {
let row_id_index = source.schema().fields.len() - 1;
let new_schema = Arc::new(Schema::new(vec![
source.schema().field(row_id_index).clone(),
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
]));
let total_weight = weights.iter().sum::<u64>() + discard_weight;
// Thresholds are the cumulative sum of the weights
let mut offset = 0;
let thresholds = weights
.iter()
.map(|w| {
let value = offset + w;
offset = value;
value
})
.collect::<Vec<_>>();
let new_schema_clone = new_schema.clone();
let stream = source.map_ok(move |batch| {
let split_ids = Self::hash_split_id(&batch, &thresholds, total_weight);
if split_ids.null_count() > 0 {
let is_valid = split_ids.nulls().unwrap().inner();
let is_valid_mask = BooleanArray::new(is_valid.clone(), None);
let split_ids = arrow::compute::filter(&split_ids, &is_valid_mask).unwrap();
let row_ids = batch.column(row_id_index);
let row_ids = arrow::compute::filter(row_ids.as_ref(), &is_valid_mask).unwrap();
RecordBatch::try_new(new_schema.clone(), vec![row_ids, split_ids]).unwrap()
} else {
RecordBatch::try_new(
new_schema.clone(),
vec![batch.column(row_id_index).clone(), Arc::new(split_ids)],
)
.unwrap()
}
});
Ok(Box::pin(SimpleRecordBatchStream::new(
stream,
new_schema_clone,
)))
}
pub async fn apply(
&self,
source: SendableRecordBatchStream,
num_rows: u64,
) -> Result<SendableRecordBatchStream> {
self.strategy.validate(num_rows)?;
match &self.strategy {
// For consistency, even when not splitting, we still add a split id column of all 0s
SplitStrategy::NoSplit => {
self.apply_sequential(source, num_rows, &SplitSizes::Counts(vec![num_rows]))
.await
}
SplitStrategy::Random { seed, sizes } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed: *seed,
// In this case we are only shuffling row ids so we can use a large max_rows_per_file
max_rows_per_file: 10 * 1024 * 1024,
temp_dir: self.temp_dir.clone(),
clump_size: None,
});
let shuffled = shuffler.shuffle(source, num_rows).await?;
self.apply_sequential(shuffled, num_rows, sizes).await
}
SplitStrategy::Sequential { sizes } => {
self.apply_sequential(source, num_rows, sizes).await
}
// Nothing to do, split is calculated in projection
SplitStrategy::Calculated { .. } => Ok(source),
SplitStrategy::Hash {
split_weights,
discard_weight,
..
} => {
self.apply_hash(source, split_weights, *discard_weight)
.await
}
}
}
pub fn project(&self, query: Query) -> Query {
match &self.strategy {
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
SPLIT_ID_COLUMN.to_string(),
calculation.clone(),
)])),
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
_ => query,
}
}
pub fn orders_by_split_id(&self) -> bool {
match &self.strategy {
// Hash and Calculated assign split ids row by row, so their output is not grouped
// by split id and the permutation builder must sort it afterwards.
SplitStrategy::Hash { .. } | SplitStrategy::Calculated { .. } => false,
SplitStrategy::NoSplit
| SplitStrategy::Sequential { .. }
// It may be strange but for random we shuffle and then assign splits so the result is
// sorted by split id
| SplitStrategy::Random { .. } => true,
}
}
}
/// Split configuration - either percentages or absolute counts
///
/// If the percentages do not sum to 1.0 (or the counts do not sum to the total number of rows)
/// the remaining rows will not be included in the permutation.
///
/// The default implementation assigns all rows to a single split.
#[derive(Debug, Clone)]
pub enum SplitSizes {
/// Percentage splits (must sum to <= 1.0)
///
/// The number of rows in each split is the nearest integer to the percentage multiplied by
/// the total number of rows.
Percentages(Vec<f64>),
/// Absolute row counts per split
///
/// If the dataset doesn't contain enough matching rows to fill all splits then an error
/// will be raised.
Counts(Vec<u64>),
/// Divides data into a fixed number of splits
///
/// Will divide the data evenly.
///
/// If the number of rows is not divisible by the number of splits then the row count per split
/// is rounded down.
Fixed(u64),
}
impl Default for SplitSizes {
fn default() -> Self {
Self::Percentages(vec![1.0])
}
}
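// A minimal sketch (not from the original source) of how the three variants above resolve
// to concrete row counts for a 100-row table via to_counts below.
#[allow(dead_code)]
fn example_split_sizes_to_counts() {
// 30% / 70% of 100 rows -> [30, 70].
assert_eq!(
SplitSizes::Percentages(vec![0.3, 0.7]).to_counts(100),
vec![30, 70]
);
// Explicit counts pass through unchanged (the remaining 25 rows are left out).
assert_eq!(SplitSizes::Counts(vec![25, 50]).to_counts(100), vec![25, 50]);
// Fixed(3) rounds down: 100 / 3 = 33 rows per split, with one row left over.
assert_eq!(SplitSizes::Fixed(3).to_counts(100), vec![33, 33, 33]);
}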
impl SplitSizes {
pub fn validate(&self, num_rows: u64) -> Result<()> {
match self {
Self::Percentages(percentages) => {
for percentage in percentages {
if *percentage < 0.0 || *percentage > 1.0 {
return Err(Error::InvalidInput {
message: "Split percentages must be between 0.0 and 1.0".to_string(),
});
}
if percentage * (num_rows as f64) < 1.0 {
return Err(Error::InvalidInput {
message: format!(
"One of the splits has {}% of {} rows which rounds to 0 rows",
percentage * 100.0,
num_rows
),
});
}
}
if percentages.iter().sum::<f64>() > 1.0 {
return Err(Error::InvalidInput {
message: "Split percentages must sum to 1.0 or less".to_string(),
});
}
}
Self::Counts(counts) => {
if counts.iter().sum::<u64>() > num_rows {
return Err(Error::InvalidInput {
message: format!(
"Split counts specified {} rows but only {} are available",
counts.iter().sum::<u64>(),
num_rows
),
});
}
if counts.contains(&0) {
return Err(Error::InvalidInput {
message: "Split counts must be greater than 0".to_string(),
});
}
}
Self::Fixed(num_splits) => {
if *num_splits == 0 {
return Err(Error::InvalidInput {
message: "Split fixed config must specify at least 1 split".to_string(),
});
}
if *num_splits > num_rows {
return Err(Error::InvalidInput {
message: format!(
"Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
*num_splits, num_rows
),
});
}
}
}
Ok(())
}
pub fn to_counts(&self, num_rows: u64) -> Vec<u64> {
let sizes = match self {
Self::Percentages(percentages) => {
let mut percentage_sum = 0.0_f64;
let mut counts = percentages
.iter()
.map(|p| {
let count = (p * (num_rows as f64)).round() as u64;
percentage_sum += p;
count
})
.collect::<Vec<_>>();
let sum = counts.iter().sum::<u64>();
let is_basically_one =
(num_rows as f64 - percentage_sum * num_rows as f64).abs() < 0.5;
// If the sum of percentages is close to 1.0 then rounding errors can add up
// to more or less than num_rows
//
// Drop items from buckets until we have the correct number of rows
let mut excess = sum as i64 - num_rows as i64;
let mut drop_idx = 0;
while excess > 0 {
if counts[drop_idx] > 0 {
counts[drop_idx] -= 1;
excess -= 1;
}
drop_idx += 1;
if drop_idx == counts.len() {
drop_idx = 0;
}
}
// On the other hand, if the percentages sum to ~1.0 then we also shouldn't _lose_
// rows due to rounding errors
let mut add_idx = 0;
while is_basically_one && excess < 0 {
counts[add_idx] += 1;
add_idx += 1;
excess += 1;
if add_idx == counts.len() {
add_idx = 0;
}
}
counts
}
Self::Counts(counts) => counts.clone(),
Self::Fixed(num_splits) => {
let rows_per_split = num_rows / *num_splits;
vec![rows_per_split; *num_splits as usize]
}
};
assert!(sizes.iter().sum::<u64>() <= num_rows);
sizes
}
}
#[cfg(test)]
mod tests {
use crate::arrow::LanceDbDatagenExt;
use super::*;
use arrow::{
array::AsArray,
compute::concat_batches,
datatypes::{Int32Type, UInt64Type},
};
use arrow_array::Int32Array;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, ByteCount, RowCount, Seed};
use std::sync::Arc;
const ID_COLUMN: &str = "id";
#[test]
fn test_split_sizes_percentages_validation() {
// Valid percentages
let sizes = SplitSizes::Percentages(vec![0.7, 0.3]);
assert!(sizes.validate(100).is_ok());
// Sum > 1.0
let sizes = SplitSizes::Percentages(vec![0.7, 0.4]);
assert!(sizes.validate(100).is_err());
// Negative percentage
let sizes = SplitSizes::Percentages(vec![-0.1, 0.5]);
assert!(sizes.validate(100).is_err());
// Percentage > 1.0
let sizes = SplitSizes::Percentages(vec![1.5]);
assert!(sizes.validate(100).is_err());
// Percentage rounds to 0 rows
let sizes = SplitSizes::Percentages(vec![0.001]);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_counts_validation() {
// Valid counts
let sizes = SplitSizes::Counts(vec![30, 70]);
assert!(sizes.validate(100).is_ok());
// Sum > num_rows
let sizes = SplitSizes::Counts(vec![60, 50]);
assert!(sizes.validate(100).is_err());
// Counts are 0
let sizes = SplitSizes::Counts(vec![0, 100]);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_fixed_validation() {
// Valid fixed splits
let sizes = SplitSizes::Fixed(5);
assert!(sizes.validate(100).is_ok());
// More splits than rows
let sizes = SplitSizes::Fixed(150);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_to_sizes_percentages() {
let sizes = SplitSizes::Percentages(vec![0.3, 0.7]);
let result = sizes.to_counts(100);
assert_eq!(result, vec![30, 70]);
// Test rounding
let sizes = SplitSizes::Percentages(vec![0.3, 0.41]);
let result = sizes.to_counts(70);
assert_eq!(result, vec![21, 29]);
}
#[test]
fn test_split_sizes_to_sizes_fixed() {
let sizes = SplitSizes::Fixed(4);
let result = sizes.to_counts(100);
assert_eq!(result, vec![25, 25, 25, 25]);
// Test with remainder
let sizes = SplitSizes::Fixed(3);
let result = sizes.to_counts(10);
assert_eq!(result, vec![3, 3, 3]);
}
fn test_data() -> SendableRecordBatchStream {
lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(10), BatchCount::from(5))
}
async fn verify_splitter(
splitter: Splitter,
data: SendableRecordBatchStream,
num_rows: u64,
expected_split_sizes: &[u64],
row_ids_in_order: bool,
) {
let split_batches = splitter
.apply(data, num_rows)
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = split_batches[0].schema();
let split_batch = concat_batches(&schema, &split_batches).unwrap();
let total_split_sizes = expected_split_sizes.iter().sum::<u64>();
assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
let mut expected = Vec::with_capacity(total_split_sizes as usize);
for (i, size) in expected_split_sizes.iter().enumerate() {
expected.extend(iter::repeat(i as u64).take(*size as usize));
}
let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;
assert_eq!(&expected, split_batch.column(1));
let expected_row_ids =
Arc::new(Int32Array::from_iter_values(0..total_split_sizes as i32)) as Arc<dyn Array>;
if row_ids_in_order {
assert_eq!(&expected_row_ids, split_batch.column(0));
} else {
assert_ne!(&expected_row_ids, split_batch.column(0));
}
}
#[tokio::test]
async fn test_fixed_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Fixed(3),
},
);
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], true).await;
}
#[tokio::test]
async fn test_fixed_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Fixed(3),
},
);
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], false).await;
}
#[tokio::test]
async fn test_counts_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Counts(vec![5, 15, 10]),
},
);
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], true).await;
}
#[tokio::test]
async fn test_counts_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Counts(vec![5, 15, 10]),
},
);
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], false).await;
}
#[tokio::test]
async fn test_percentages_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
},
);
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], true).await;
}
#[tokio::test]
async fn test_percentages_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
},
);
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], false).await;
}
#[tokio::test]
async fn test_hash_split() {
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col(
"hash1",
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
)
.col("hash2", lance_datagen::array::step::<Int32Type>())
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(10), BatchCount::from(5));
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Hash {
columns: vec!["hash1".to_string(), "hash2".to_string()],
split_weights: vec![1, 2],
discard_weight: 1,
},
);
let split_batches = splitter
.apply(data, 10)
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = split_batches[0].schema();
let split_batch = concat_batches(&schema, &split_batches).unwrap();
// These assertions are all based on fixed seed in data generation but they match
// up roughly to what we expect (25% discarded, 25% in split 0, 50% in split 1)
// 14 rows (28%) are discarded because discard_weight is 1
assert_eq!(split_batch.num_rows(), 36);
assert_eq!(split_batch.num_columns(), 2);
let split_ids = split_batch.column(1).as_primitive::<UInt64Type>().values();
let num_in_split_0 = split_ids.iter().filter(|v| **v == 0).count();
let num_in_split_1 = split_ids.iter().filter(|v| **v == 1).count();
assert_eq!(num_in_split_0, 11); // 22%
assert_eq!(num_in_split_1, 25); // 50%
}
}

View File

@@ -1,98 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{path::PathBuf, sync::Arc};
use arrow_array::RecordBatch;
use arrow_schema::{Fields, Schema};
use datafusion_execution::disk_manager::DiskManagerMode;
use futures::TryStreamExt;
use rand::{rngs::SmallRng, RngCore, SeedableRng};
use tempfile::TempDir;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
Error, Result,
};
/// Directory to use for temporary files
#[derive(Debug, Clone, Default)]
pub enum TemporaryDirectory {
/// Use the operating system's default temporary directory (e.g. /tmp)
#[default]
OsDefault,
/// Use the specified directory (must be an absolute path)
Specific(PathBuf),
/// If spilling is required, then error out
None,
}
impl TemporaryDirectory {
pub fn create_temp_dir(&self) -> Result<TempDir> {
match self {
Self::OsDefault => tempfile::tempdir(),
Self::Specific(path) => tempfile::Builder::default().tempdir_in(path),
Self::None => {
return Err(Error::Runtime {
message: "No temporary directory was supplied and this operation requires spilling to disk".to_string(),
});
}
}
.map_err(|err| Error::Other {
message: "Failed to create temporary directory".to_string(),
source: Some(err.into()),
})
}
pub fn to_disk_manager_mode(&self) -> DiskManagerMode {
match self {
Self::OsDefault => DiskManagerMode::OsTmpDirectory,
Self::Specific(path) => DiskManagerMode::Directories(vec![path.clone()]),
Self::None => DiskManagerMode::Disabled,
}
}
}
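// Hedged sketch (not in the original file): the three spill policies side by side; the
// scratch path is a hypothetical example.
#[allow(dead_code)]
fn example_spill_policies() -> Vec<TemporaryDirectory> {
vec![
// Spill to the operating system's temp directory (e.g. /tmp).
TemporaryDirectory::OsDefault,
// Spill to a caller-chosen absolute path.
TemporaryDirectory::Specific(PathBuf::from("/mnt/scratch")),
// Refuse to spill: create_temp_dir() returns an error instead.
TemporaryDirectory::None,
]
}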
pub fn non_crypto_rng(seed: &Option<u64>) -> Box<dyn RngCore + Send> {
Box::new(
seed.as_ref()
.map(|seed| SmallRng::seed_from_u64(*seed))
.unwrap_or_else(SmallRng::from_os_rng),
)
}
pub fn rename_column(
stream: SendableRecordBatchStream,
old_name: &str,
new_name: &str,
) -> Result<SendableRecordBatchStream> {
let schema = stream.schema();
let field_index = schema.index_of(old_name)?;
let new_fields = schema
.fields
.iter()
.cloned()
.enumerate()
.map(|(idx, f)| {
if idx == field_index {
Arc::new(f.as_ref().clone().with_name(new_name))
} else {
f
}
})
.collect::<Fields>();
let new_schema = Arc::new(Schema::new(new_fields).with_metadata(schema.metadata().clone()));
let new_schema_clone = new_schema.clone();
let renamed_stream = stream.and_then(move |batch| {
let renamed_batch =
RecordBatch::try_new(new_schema.clone(), batch.columns().to_vec()).map_err(Error::from);
std::future::ready(renamed_batch)
});
Ok(Box::pin(SimpleRecordBatchStream::new(
renamed_stream,
new_schema_clone,
)))
}

View File

@@ -194,7 +194,6 @@ pub mod arrow;
pub mod connection;
pub mod data;
pub mod database;
pub mod dataloader;
pub mod embeddings;
pub mod error;
pub mod index;

View File

@@ -16,7 +16,7 @@ use tokio::task::spawn_blocking;
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
OpenTableRequest, ReadConsistency, TableNamesRequest,
OpenTableRequest, TableNamesRequest,
};
use crate::error::Result;
use crate::table::BaseTable;
@@ -189,7 +189,6 @@ struct ListTablesResponse {
pub struct RemoteDatabase<S: HttpSend = Sender> {
client: RestfulLanceDbClient<S>,
table_cache: Cache<String, Arc<RemoteTable<S>>>,
uri: String,
}
impl RemoteDatabase {
@@ -218,7 +217,6 @@ impl RemoteDatabase {
Ok(Self {
client,
table_cache,
uri: uri.to_owned(),
})
}
}
@@ -240,7 +238,6 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
}
}
@@ -253,7 +250,6 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
}
}
}
@@ -319,17 +315,6 @@ fn build_cache_key(name: &str, namespace: &[String]) -> String {
#[async_trait]
impl<S: HttpSend> Database for RemoteDatabase<S> {
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
Err(Error::NotSupported {
message: "Getting the read consistency of a remote database is not yet supported"
.to_string(),
})
}
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
let mut req = if !request.namespace.is_empty() {
let namespace_id =
@@ -662,7 +647,7 @@ impl From<StorageOptions> for RemoteOptions {
let mut filtered = HashMap::new();
for opt in supported_opts {
if let Some(v) = options.0.get(opt) {
filtered.insert(opt.to_string(), v.clone());
filtered.insert(opt.to_string(), v.to_string());
}
}
Self::new(filtered)

View File

@@ -50,7 +50,6 @@ use std::sync::Arc;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::database::Database;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
@@ -612,10 +611,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// A Table is a collection of strong typed Rows.
///
/// The type of the each row is defined in Apache Arrow [Schema].
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct Table {
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -633,13 +631,11 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
handler.clone(),
handler,
None,
));
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -655,13 +651,11 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
handler.clone(),
handler,
Some(version),
));
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -676,10 +670,9 @@ impl std::fmt::Display for Table {
}
impl Table {
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
pub fn new(inner: Arc<dyn BaseTable>) -> Self {
Self {
inner,
database,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -688,22 +681,12 @@ impl Table {
&self.inner
}
pub fn database(&self) -> &Arc<dyn Database> {
&self.database
}
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
&self.embedding_registry
}
pub(crate) fn new_with_embedding_registry(
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
inner,
database,
embedding_registry,
}
}
@@ -1400,39 +1383,40 @@ impl Table {
}
pub struct NativeTags {
dataset: dataset::DatasetConsistencyWrapper,
inner: LanceTags,
}
#[async_trait]
impl Tags for NativeTags {
async fn list(&self) -> Result<HashMap<String, TagContents>> {
let dataset = self.dataset.get().await?;
Ok(dataset.tags().list().await?)
Ok(self.inner.list().await?)
}
async fn get_version(&self, tag: &str) -> Result<u64> {
let dataset = self.dataset.get().await?;
Ok(dataset.tags().get_version(tag).await?)
Ok(self.inner.get_version(tag).await?)
}
async fn create(&mut self, tag: &str, version: u64) -> Result<()> {
let dataset = self.dataset.get().await?;
dataset.tags().create(tag, version).await?;
self.inner.create(tag, version).await?;
Ok(())
}
async fn delete(&mut self, tag: &str) -> Result<()> {
let dataset = self.dataset.get().await?;
dataset.tags().delete(tag).await?;
self.inner.delete(tag).await?;
Ok(())
}
async fn update(&mut self, tag: &str, version: u64) -> Result<()> {
let dataset = self.dataset.get().await?;
dataset.tags().update(tag, version).await?;
self.inner.update(tag, version).await?;
Ok(())
}
}
impl From<NativeTable> for Table {
fn from(table: NativeTable) -> Self {
Self::new(Arc::new(table))
}
}
pub trait NativeTableExt {
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;
@@ -1796,13 +1780,13 @@ impl NativeTable {
BuiltinIndexType::BTree,
)))
} else {
Err(Error::InvalidInput {
return Err(Error::InvalidInput {
message: format!(
"there are no indices supported for the field `{}` with the data type {}",
field.name(),
field.data_type()
),
})?
});
}
}
Index::BTree(_) => {
@@ -2470,8 +2454,10 @@ impl BaseTable for NativeTable {
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
let dataset = self.dataset.get().await?;
Ok(Box::new(NativeTags {
dataset: self.dataset.clone(),
inner: dataset.tags.clone(),
}))
}

View File

@@ -125,10 +125,6 @@ impl ExecutionPlan for MetadataEraserExec {
fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
self.input.partition_statistics(partition)
}
fn supports_limit_pushdown(&self) -> bool {
true
}
}
#[derive(Debug)]
@@ -176,7 +172,7 @@ impl TableProvider for BaseTableAdapter {
if let Some(projection) = projection {
let field_names = projection
.iter()
.map(|i| self.schema.field(*i).name().clone())
.map(|i| self.schema.field(*i).name().to_string())
.collect();
query.select = Select::Columns(field_names);
}

View File

@@ -98,9 +98,8 @@ impl DatasetRef {
}
Self::TimeTravel { dataset, version } => {
let should_checkout = match &target_ref {
refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
refs::Ref::Version(_, None) => true, // No specific version, always checkout
refs::Ref::Tag(_) => true, // Always checkout for tags
refs::Ref::Version(target_ver) => version != target_ver,
refs::Ref::Tag(_) => true, // Always checkout for tags
};
if should_checkout {

View File

@@ -39,7 +39,7 @@ impl PatchStoreParam for Option<ObjectStoreParams> {
let mut params = self.unwrap_or_default();
if params.object_store_wrapper.is_some() {
return Err(Error::Other {
message: "can not patch param because object store is already set".into(),
message: "can not patch param because object store is already set.".into(),
source: None,
});
}
@@ -174,7 +174,7 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
),
})
} else {
Ok(candidates[0].clone())
Ok(candidates[0].to_string())
}
}