mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 22:09:58 +00:00
Compare commits
28 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02d31ee412 | ||
|
|
308623577d | ||
|
|
8ee3ae378f | ||
|
|
3372a2aae0 | ||
|
|
4cfcd95320 | ||
|
|
a70ff04bc9 | ||
|
|
a9daa18be9 | ||
|
|
3f2e3986e9 | ||
|
|
bf55feb9b6 | ||
|
|
8f8e06a2da | ||
|
|
03eab0f091 | ||
|
|
143184c0ae | ||
|
|
dadb042978 | ||
|
|
5a19cf15a6 | ||
|
|
3dcec724b7 | ||
|
|
86a6bb9fcb | ||
|
|
b59d1007d3 | ||
|
|
56a16b1728 | ||
|
|
b7afed9beb | ||
|
|
5cbbaa2e4a | ||
|
|
1b6bd2498e | ||
|
|
285da9db1d | ||
|
|
ad8306c96b | ||
|
|
3594538509 | ||
|
|
917aabd077 | ||
|
|
5ec12c9971 | ||
|
|
d0ce489b21 | ||
|
|
d7e02c8181 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.22.2-beta.1"
|
||||
current_version = "0.22.3-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
106
.github/workflows/codex-update-lance-dependency.yml
vendored
Normal file
106
.github/workflows/codex-update-lance-dependency.yml
vendored
Normal file
@@ -0,0 +1,106 @@
|
||||
name: Codex Update Lance Dependency
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag name from Lance"
|
||||
required: true
|
||||
type: string
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag name from Lance"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
actions: read
|
||||
|
||||
jobs:
|
||||
update:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Show inputs
|
||||
run: |
|
||||
echo "tag = ${{ inputs.tag }}"
|
||||
|
||||
- name: Checkout Repo LanceDB
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
|
||||
- name: Install Codex CLI
|
||||
run: npm install -g @openai/codex
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
toolchain: stable
|
||||
components: clippy, rustfmt
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y protobuf-compiler libssl-dev
|
||||
|
||||
- name: Install cargo-info
|
||||
run: cargo install cargo-info
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: python3 -m pip install --upgrade pip packaging
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config user.name "lancedb automation"
|
||||
git config user.email "automation@lancedb.com"
|
||||
|
||||
- name: Configure Codex authentication
|
||||
env:
|
||||
CODEX_TOKEN_B64: ${{ secrets.CODEX_TOKEN }}
|
||||
run: |
|
||||
if [ -z "${CODEX_TOKEN_B64}" ]; then
|
||||
echo "Repository secret CODEX_TOKEN is not defined; skipping Codex execution."
|
||||
exit 1
|
||||
fi
|
||||
mkdir -p ~/.codex
|
||||
echo "${CODEX_TOKEN_B64}" | base64 --decode > ~/.codex/auth.json
|
||||
|
||||
- name: Run Codex to update Lance dependency
|
||||
env:
|
||||
TAG: ${{ inputs.tag }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
VERSION="${TAG#refs/tags/}"
|
||||
VERSION="${VERSION#v}"
|
||||
BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
|
||||
cat <<EOF >/tmp/codex-prompt.txt
|
||||
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
|
||||
|
||||
Follow these steps exactly:
|
||||
1. Use script `ci/set_lance_version.py` to update Lance dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
|
||||
2. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
|
||||
3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
|
||||
4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
|
||||
5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
|
||||
6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
|
||||
7. Push the branch to origin. If the branch already exists, force-push your changes.
|
||||
8. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Use the GitHub CLI if helpful.
|
||||
9. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
|
||||
|
||||
Constraints:
|
||||
- Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.
|
||||
- Do not merge the PR.
|
||||
- If any command fails, diagnose and fix the issue instead of aborting.
|
||||
EOF
|
||||
codex exec --dangerously-bypass-approvals-and-sandbox "$(cat /tmp/codex-prompt.txt)"
|
||||
770
Cargo.lock
generated
770
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
69
Cargo.toml
69
Cargo.toml
@@ -15,30 +15,36 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.78.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.38.0", default-features = false, "features" = ["dynamodb"] }
|
||||
lance-io = { "version" = "=0.38.0", default-features = false }
|
||||
lance-index = "=0.38.0"
|
||||
lance-linalg = "=0.38.0"
|
||||
lance-table = "=0.38.0"
|
||||
lance-testing = "=0.38.0"
|
||||
lance-datafusion = "=0.38.0"
|
||||
lance-encoding = "=0.38.0"
|
||||
lance-namespace = "0.0.16"
|
||||
lance = { "version" = "=0.38.3-beta.6", default-features = false, "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-core = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-datagen = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-file = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-io = { "version" = "=0.38.3-beta.6", default-features = false, "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-index = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-linalg = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-namespace = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-table = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-testing = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-datafusion = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-encoding = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "55.1", optional = false }
|
||||
arrow-array = "55.1"
|
||||
arrow-data = "55.1"
|
||||
arrow-ipc = "55.1"
|
||||
arrow-ord = "55.1"
|
||||
arrow-schema = "55.1"
|
||||
arrow-cast = "55.1"
|
||||
arrow = { version = "56.2", optional = false }
|
||||
arrow-array = "56.2"
|
||||
arrow-data = "56.2"
|
||||
arrow-ipc = "56.2"
|
||||
arrow-ord = "56.2"
|
||||
arrow-schema = "56.2"
|
||||
arrow-select = "56.2"
|
||||
arrow-cast = "56.2"
|
||||
async-trait = "0"
|
||||
datafusion = { version = "49.0", default-features = false }
|
||||
datafusion-catalog = "49.0"
|
||||
datafusion-common = { version = "49.0", default-features = false }
|
||||
datafusion-execution = "49.0"
|
||||
datafusion-expr = "49.0"
|
||||
datafusion-physical-plan = "49.0"
|
||||
datafusion = { version = "50.1", default-features = false }
|
||||
datafusion-catalog = "50.1"
|
||||
datafusion-common = { version = "50.1", default-features = false }
|
||||
datafusion-execution = "50.1"
|
||||
datafusion-expr = "50.1"
|
||||
datafusion-physical-plan = "50.1"
|
||||
env_logger = "0.11"
|
||||
half = { "version" = "2.6.0", default-features = false, features = [
|
||||
"num-traits",
|
||||
@@ -48,6 +54,7 @@ log = "0.4"
|
||||
moka = { version = "0.12", features = ["future"] }
|
||||
object_store = "0.12.0"
|
||||
pin-project = "1.0.7"
|
||||
rand = "0.9"
|
||||
snafu = "0.8"
|
||||
url = "2"
|
||||
num-traits = "0.2"
|
||||
@@ -55,20 +62,6 @@ regex = "1.10"
|
||||
lazy_static = "1"
|
||||
semver = "1.0.25"
|
||||
crunchy = "0.2.4"
|
||||
# Temporary pins to work around downstream issues
|
||||
# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
|
||||
chrono = "=0.4.41"
|
||||
chrono = "0.4"
|
||||
# Workaround for: https://github.com/Lokathor/bytemuck/issues/306
|
||||
bytemuck_derive = ">=1.8.1, <1.9.0"
|
||||
|
||||
# This is only needed when we reference preview releases of lance
|
||||
# [patch.crates-io]
|
||||
# # Force to use the same lance version as the rest of the project to avoid duplicate dependencies
|
||||
# lance = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# lance-io = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# lance-index = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# lance-linalg = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# lance-table = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# lance-testing = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# lance-datafusion = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# lance-encoding = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
|
||||
bytemuck_derive = ">=1.8.1, <1.9.0"
|
||||
@@ -117,7 +117,7 @@ def update_cargo_toml(line_updater):
|
||||
lance_line = ""
|
||||
is_parsing_lance_line = False
|
||||
for line in lines:
|
||||
if line.startswith("lance") and not line.startswith("lance-namespace"):
|
||||
if line.startswith("lance"):
|
||||
# Check if this is a single-line or multi-line entry
|
||||
# Single-line entries either:
|
||||
# 1. End with } (complete inline table)
|
||||
@@ -183,10 +183,8 @@ def set_preview_version(version: str):
|
||||
|
||||
def line_updater(line: str) -> str:
|
||||
package_name = line.split("=", maxsplit=1)[0].strip()
|
||||
base_version = version.split("-")[0] # Get the base version without beta suffix
|
||||
|
||||
# Build config in desired order: version, default-features, features, tag, git
|
||||
config = {"version": f"={base_version}"}
|
||||
config = {"version": f"={version}"}
|
||||
|
||||
if extract_default_features(line):
|
||||
config["default-features"] = False
|
||||
|
||||
@@ -84,6 +84,7 @@ plugins:
|
||||
'examples.md': 'https://lancedb.com/docs/tutorials/'
|
||||
'concepts/vector_search.md': 'https://lancedb.com/docs/search/vector-search/'
|
||||
'troubleshooting.md': 'https://lancedb.com/docs/troubleshooting/'
|
||||
'guides/storage.md': 'https://lancedb.com/docs/storage/integrations'
|
||||
|
||||
|
||||
|
||||
@@ -402,4 +403,4 @@ extra:
|
||||
- icon: fontawesome/brands/x-twitter
|
||||
link: https://twitter.com/lancedb
|
||||
- icon: fontawesome/brands/linkedin
|
||||
link: https://www.linkedin.com/company/lancedb
|
||||
link: https://www.linkedin.com/company/lancedb
|
||||
|
||||
@@ -194,6 +194,37 @@ currently is also a memory intensive operation.
|
||||
|
||||
***
|
||||
|
||||
### ivfRq()
|
||||
|
||||
```ts
|
||||
static ivfRq(options?): Index
|
||||
```
|
||||
|
||||
Create an IvfRq index
|
||||
|
||||
IVF-RQ (RabitQ Quantization) compresses vectors using RabitQ quantization
|
||||
and organizes them into IVF partitions.
|
||||
|
||||
The compression scheme is called RabitQ quantization. Each dimension is quantized into a small number of bits.
|
||||
The parameters `num_bits` and `num_partitions` control this process, providing a tradeoff
|
||||
between index size (and thus search speed) and index accuracy.
|
||||
|
||||
The partitioning process is called IVF and the `num_partitions` parameter controls how
|
||||
many groups to create.
|
||||
|
||||
Note that training an IVF RQ index on a large dataset is a slow operation and
|
||||
currently is also a memory intensive operation.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **options?**: `Partial`<[`IvfRqOptions`](../interfaces/IvfRqOptions.md)>
|
||||
|
||||
#### Returns
|
||||
|
||||
[`Index`](Index.md)
|
||||
|
||||
***
|
||||
|
||||
### labelList()
|
||||
|
||||
```ts
|
||||
|
||||
220
docs/src/js/classes/PermutationBuilder.md
Normal file
220
docs/src/js/classes/PermutationBuilder.md
Normal file
@@ -0,0 +1,220 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / PermutationBuilder
|
||||
|
||||
# Class: PermutationBuilder
|
||||
|
||||
A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
|
||||
|
||||
This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
|
||||
offering methods to configure data splits, shuffling, and filtering before executing
|
||||
the permutation to create a new table.
|
||||
|
||||
## Methods
|
||||
|
||||
### execute()
|
||||
|
||||
```ts
|
||||
execute(): Promise<Table>
|
||||
```
|
||||
|
||||
Execute the permutation and create the destination table.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<[`Table`](Table.md)>
|
||||
|
||||
A Promise that resolves to the new Table instance
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
const permutationTable = await builder.execute();
|
||||
console.log(`Created table: ${permutationTable.name}`);
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### filter()
|
||||
|
||||
```ts
|
||||
filter(filter): PermutationBuilder
|
||||
```
|
||||
|
||||
Configure filtering for the permutation.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **filter**: `string`
|
||||
SQL filter expression
|
||||
|
||||
#### Returns
|
||||
|
||||
[`PermutationBuilder`](PermutationBuilder.md)
|
||||
|
||||
A new PermutationBuilder instance
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
builder.filter("age > 18 AND status = 'active'");
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### shuffle()
|
||||
|
||||
```ts
|
||||
shuffle(options): PermutationBuilder
|
||||
```
|
||||
|
||||
Configure shuffling for the permutation.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
|
||||
Configuration for shuffling
|
||||
|
||||
#### Returns
|
||||
|
||||
[`PermutationBuilder`](PermutationBuilder.md)
|
||||
|
||||
A new PermutationBuilder instance
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
// Basic shuffle
|
||||
builder.shuffle({ seed: 42 });
|
||||
|
||||
// Shuffle with clump size
|
||||
builder.shuffle({ seed: 42, clumpSize: 10 });
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### splitCalculated()
|
||||
|
||||
```ts
|
||||
splitCalculated(calculation): PermutationBuilder
|
||||
```
|
||||
|
||||
Configure calculated splits for the permutation.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **calculation**: `string`
|
||||
SQL expression for calculating splits
|
||||
|
||||
#### Returns
|
||||
|
||||
[`PermutationBuilder`](PermutationBuilder.md)
|
||||
|
||||
A new PermutationBuilder instance
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
builder.splitCalculated("user_id % 3");
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### splitHash()
|
||||
|
||||
```ts
|
||||
splitHash(options): PermutationBuilder
|
||||
```
|
||||
|
||||
Configure hash-based splits for the permutation.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
|
||||
Configuration for hash-based splitting
|
||||
|
||||
#### Returns
|
||||
|
||||
[`PermutationBuilder`](PermutationBuilder.md)
|
||||
|
||||
A new PermutationBuilder instance
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
builder.splitHash({
|
||||
columns: ["user_id"],
|
||||
splitWeights: [70, 30],
|
||||
discardWeight: 0
|
||||
});
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### splitRandom()
|
||||
|
||||
```ts
|
||||
splitRandom(options): PermutationBuilder
|
||||
```
|
||||
|
||||
Configure random splits for the permutation.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
|
||||
Configuration for random splitting
|
||||
|
||||
#### Returns
|
||||
|
||||
[`PermutationBuilder`](PermutationBuilder.md)
|
||||
|
||||
A new PermutationBuilder instance
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
// Split by ratios
|
||||
builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
|
||||
|
||||
// Split by counts
|
||||
builder.splitRandom({ counts: [1000, 500], seed: 42 });
|
||||
|
||||
// Split with fixed size
|
||||
builder.splitRandom({ fixed: 100, seed: 42 });
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### splitSequential()
|
||||
|
||||
```ts
|
||||
splitSequential(options): PermutationBuilder
|
||||
```
|
||||
|
||||
Configure sequential splits for the permutation.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
|
||||
Configuration for sequential splitting
|
||||
|
||||
#### Returns
|
||||
|
||||
[`PermutationBuilder`](PermutationBuilder.md)
|
||||
|
||||
A new PermutationBuilder instance
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
// Split by ratios
|
||||
builder.splitSequential({ ratios: [0.8, 0.2] });
|
||||
|
||||
// Split by counts
|
||||
builder.splitSequential({ counts: [800, 200] });
|
||||
|
||||
// Split with fixed size
|
||||
builder.splitSequential({ fixed: 1000 });
|
||||
```
|
||||
@@ -343,6 +343,29 @@ This is useful for pagination.
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
#### Inherited from
|
||||
|
||||
`StandardQueryBase.outputSchema`
|
||||
|
||||
***
|
||||
|
||||
### select()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -140,6 +140,25 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
***
|
||||
|
||||
### select()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -143,6 +143,29 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
#### Inherited from
|
||||
|
||||
[`QueryBase`](QueryBase.md).[`outputSchema`](QueryBase.md#outputschema)
|
||||
|
||||
***
|
||||
|
||||
### select()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -498,6 +498,29 @@ This is useful for pagination.
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
#### Inherited from
|
||||
|
||||
`StandardQueryBase.outputSchema`
|
||||
|
||||
***
|
||||
|
||||
### postfilter()
|
||||
|
||||
```ts
|
||||
|
||||
34
docs/src/js/functions/permutationBuilder.md
Normal file
34
docs/src/js/functions/permutationBuilder.md
Normal file
@@ -0,0 +1,34 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / permutationBuilder
|
||||
|
||||
# Function: permutationBuilder()
|
||||
|
||||
```ts
|
||||
function permutationBuilder(table): PermutationBuilder
|
||||
```
|
||||
|
||||
Create a permutation builder for the given table.
|
||||
|
||||
## Parameters
|
||||
|
||||
* **table**: [`Table`](../classes/Table.md)
|
||||
The source table to create a permutation from
|
||||
|
||||
## Returns
|
||||
|
||||
[`PermutationBuilder`](../classes/PermutationBuilder.md)
|
||||
|
||||
A PermutationBuilder instance
|
||||
|
||||
## Example
|
||||
|
||||
```ts
|
||||
const builder = permutationBuilder(sourceTable, "training_data")
|
||||
.splitRandom({ ratios: [0.8, 0.2], seed: 42 })
|
||||
.shuffle({ seed: 123 });
|
||||
|
||||
const trainingTable = await builder.execute();
|
||||
```
|
||||
@@ -28,6 +28,7 @@
|
||||
- [MultiMatchQuery](classes/MultiMatchQuery.md)
|
||||
- [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md)
|
||||
- [OAuthHeaderProvider](classes/OAuthHeaderProvider.md)
|
||||
- [PermutationBuilder](classes/PermutationBuilder.md)
|
||||
- [PhraseQuery](classes/PhraseQuery.md)
|
||||
- [Query](classes/Query.md)
|
||||
- [QueryBase](classes/QueryBase.md)
|
||||
@@ -68,6 +69,7 @@
|
||||
- [IndexStatistics](interfaces/IndexStatistics.md)
|
||||
- [IvfFlatOptions](interfaces/IvfFlatOptions.md)
|
||||
- [IvfPqOptions](interfaces/IvfPqOptions.md)
|
||||
- [IvfRqOptions](interfaces/IvfRqOptions.md)
|
||||
- [MergeResult](interfaces/MergeResult.md)
|
||||
- [OpenTableOptions](interfaces/OpenTableOptions.md)
|
||||
- [OptimizeOptions](interfaces/OptimizeOptions.md)
|
||||
@@ -75,6 +77,10 @@
|
||||
- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
|
||||
- [RemovalStats](interfaces/RemovalStats.md)
|
||||
- [RetryConfig](interfaces/RetryConfig.md)
|
||||
- [ShuffleOptions](interfaces/ShuffleOptions.md)
|
||||
- [SplitHashOptions](interfaces/SplitHashOptions.md)
|
||||
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
|
||||
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
|
||||
- [TableNamesOptions](interfaces/TableNamesOptions.md)
|
||||
- [TableStatistics](interfaces/TableStatistics.md)
|
||||
- [TimeoutConfig](interfaces/TimeoutConfig.md)
|
||||
@@ -102,3 +108,4 @@
|
||||
- [connect](functions/connect.md)
|
||||
- [makeArrowTable](functions/makeArrowTable.md)
|
||||
- [packBits](functions/packBits.md)
|
||||
- [permutationBuilder](functions/permutationBuilder.md)
|
||||
|
||||
101
docs/src/js/interfaces/IvfRqOptions.md
Normal file
101
docs/src/js/interfaces/IvfRqOptions.md
Normal file
@@ -0,0 +1,101 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / IvfRqOptions
|
||||
|
||||
# Interface: IvfRqOptions
|
||||
|
||||
## Properties
|
||||
|
||||
### distanceType?
|
||||
|
||||
```ts
|
||||
optional distanceType: "l2" | "cosine" | "dot";
|
||||
```
|
||||
|
||||
Distance type to use to build the index.
|
||||
|
||||
Default value is "l2".
|
||||
|
||||
This is used when training the index to calculate the IVF partitions
|
||||
(vectors are grouped in partitions with similar vectors according to this
|
||||
distance type) and during quantization.
|
||||
|
||||
The distance type used to train an index MUST match the distance type used
|
||||
to search the index. Failure to do so will yield inaccurate results.
|
||||
|
||||
The following distance types are available:
|
||||
|
||||
"l2" - Euclidean distance.
|
||||
"cosine" - Cosine distance.
|
||||
"dot" - Dot product.
|
||||
|
||||
***
|
||||
|
||||
### maxIterations?
|
||||
|
||||
```ts
|
||||
optional maxIterations: number;
|
||||
```
|
||||
|
||||
Max iterations to train IVF kmeans.
|
||||
|
||||
When training an IVF index we use kmeans to calculate the partitions. This parameter
|
||||
controls how many iterations of kmeans to run.
|
||||
|
||||
The default value is 50.
|
||||
|
||||
***
|
||||
|
||||
### numBits?
|
||||
|
||||
```ts
|
||||
optional numBits: number;
|
||||
```
|
||||
|
||||
Number of bits per dimension for residual quantization.
|
||||
|
||||
This value controls how much each residual component is compressed. The more
|
||||
bits, the more accurate the index will be but the slower search. Typical values
|
||||
are small integers; the default is 1 bit per dimension.
|
||||
|
||||
***
|
||||
|
||||
### numPartitions?
|
||||
|
||||
```ts
|
||||
optional numPartitions: number;
|
||||
```
|
||||
|
||||
The number of IVF partitions to create.
|
||||
|
||||
This value should generally scale with the number of rows in the dataset.
|
||||
By default the number of partitions is the square root of the number of
|
||||
rows.
|
||||
|
||||
If this value is too large then the first part of the search (picking the
|
||||
right partition) will be slow. If this value is too small then the second
|
||||
part of the search (searching within a partition) will be slow.
|
||||
|
||||
***
|
||||
|
||||
### sampleRate?
|
||||
|
||||
```ts
|
||||
optional sampleRate: number;
|
||||
```
|
||||
|
||||
The number of vectors, per partition, to sample when training IVF kmeans.
|
||||
|
||||
When an IVF index is trained, we need to calculate partitions. These are groups
|
||||
of vectors that are similar to each other. To do this we use an algorithm called kmeans.
|
||||
|
||||
Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
|
||||
random sample of the data. This parameter controls the size of the sample. The total
|
||||
number of vectors used to train the index is `sample_rate * num_partitions`.
|
||||
|
||||
Increasing this value might improve the quality of the index but in most cases the
|
||||
default should be sufficient.
|
||||
|
||||
The default value is 256.
|
||||
23
docs/src/js/interfaces/ShuffleOptions.md
Normal file
23
docs/src/js/interfaces/ShuffleOptions.md
Normal file
@@ -0,0 +1,23 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / ShuffleOptions
|
||||
|
||||
# Interface: ShuffleOptions
|
||||
|
||||
## Properties
|
||||
|
||||
### clumpSize?
|
||||
|
||||
```ts
|
||||
optional clumpSize: number;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### seed?
|
||||
|
||||
```ts
|
||||
optional seed: number;
|
||||
```
|
||||
31
docs/src/js/interfaces/SplitHashOptions.md
Normal file
31
docs/src/js/interfaces/SplitHashOptions.md
Normal file
@@ -0,0 +1,31 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / SplitHashOptions
|
||||
|
||||
# Interface: SplitHashOptions
|
||||
|
||||
## Properties
|
||||
|
||||
### columns
|
||||
|
||||
```ts
|
||||
columns: string[];
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### discardWeight?
|
||||
|
||||
```ts
|
||||
optional discardWeight: number;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### splitWeights
|
||||
|
||||
```ts
|
||||
splitWeights: number[];
|
||||
```
|
||||
39
docs/src/js/interfaces/SplitRandomOptions.md
Normal file
39
docs/src/js/interfaces/SplitRandomOptions.md
Normal file
@@ -0,0 +1,39 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / SplitRandomOptions
|
||||
|
||||
# Interface: SplitRandomOptions
|
||||
|
||||
## Properties
|
||||
|
||||
### counts?
|
||||
|
||||
```ts
|
||||
optional counts: number[];
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### fixed?
|
||||
|
||||
```ts
|
||||
optional fixed: number;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### ratios?
|
||||
|
||||
```ts
|
||||
optional ratios: number[];
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### seed?
|
||||
|
||||
```ts
|
||||
optional seed: number;
|
||||
```
|
||||
31
docs/src/js/interfaces/SplitSequentialOptions.md
Normal file
31
docs/src/js/interfaces/SplitSequentialOptions.md
Normal file
@@ -0,0 +1,31 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / SplitSequentialOptions
|
||||
|
||||
# Interface: SplitSequentialOptions
|
||||
|
||||
## Properties
|
||||
|
||||
### counts?
|
||||
|
||||
```ts
|
||||
optional counts: number[];
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### fixed?
|
||||
|
||||
```ts
|
||||
optional fixed: number;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### ratios?
|
||||
|
||||
```ts
|
||||
optional ratios: number[];
|
||||
```
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.22.2-beta.1</version>
|
||||
<version>0.22.3-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.22.2-beta.1</version>
|
||||
<version>0.22.3-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.22.2-beta.1</version>
|
||||
<version>0.22.3-beta.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.22.2-beta.1"
|
||||
version = "0.22.3-beta.0"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
227
nodejs/__test__/permutation.test.ts
Normal file
227
nodejs/__test__/permutation.test.ts
Normal file
@@ -0,0 +1,227 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import * as tmp from "tmp";
|
||||
import { Table, connect, permutationBuilder } from "../lancedb";
|
||||
import { makeArrowTable } from "../lancedb/arrow";
|
||||
|
||||
describe("PermutationBuilder", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
let table: Table;
|
||||
|
||||
beforeEach(async () => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
const db = await connect(tmpDir.name);
|
||||
|
||||
// Create test data
|
||||
const data = makeArrowTable(
|
||||
[
|
||||
{ id: 1, value: 10 },
|
||||
{ id: 2, value: 20 },
|
||||
{ id: 3, value: 30 },
|
||||
{ id: 4, value: 40 },
|
||||
{ id: 5, value: 50 },
|
||||
{ id: 6, value: 60 },
|
||||
{ id: 7, value: 70 },
|
||||
{ id: 8, value: 80 },
|
||||
{ id: 9, value: 90 },
|
||||
{ id: 10, value: 100 },
|
||||
],
|
||||
{ vectorColumns: {} },
|
||||
);
|
||||
|
||||
table = await db.createTable("test_table", data);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
tmpDir.removeCallback();
|
||||
});
|
||||
|
||||
test("should create permutation builder", () => {
|
||||
const builder = permutationBuilder(table);
|
||||
expect(builder).toBeDefined();
|
||||
});
|
||||
|
||||
test("should execute basic permutation", async () => {
|
||||
const builder = permutationBuilder(table);
|
||||
const permutationTable = await builder.execute();
|
||||
|
||||
expect(permutationTable).toBeDefined();
|
||||
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with random splits", async () => {
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
ratios: [1.0],
|
||||
seed: 42,
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with percentage splits", async () => {
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
ratios: [0.3, 0.7],
|
||||
seed: 42,
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Check split distribution
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with count splits", async () => {
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
counts: [3, 7],
|
||||
seed: 42,
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Check split distribution
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBe(3);
|
||||
expect(split1Count).toBe(7);
|
||||
});
|
||||
|
||||
test("should create permutation with hash splits", async () => {
|
||||
const builder = permutationBuilder(table).splitHash({
|
||||
columns: ["id"],
|
||||
splitWeights: [50, 50],
|
||||
discardWeight: 0,
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Check that splits exist
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with sequential splits", async () => {
|
||||
const builder = permutationBuilder(table).splitSequential({
|
||||
ratios: [0.5, 0.5],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Check split distribution - sequential should give exactly 5 and 5
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBe(5);
|
||||
expect(split1Count).toBe(5);
|
||||
});
|
||||
|
||||
test("should create permutation with calculated splits", async () => {
|
||||
const builder = permutationBuilder(table).splitCalculated("id % 2");
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Check split distribution
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with shuffle", async () => {
|
||||
const builder = permutationBuilder(table).shuffle({
|
||||
seed: 42,
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with shuffle and clump size", async () => {
|
||||
const builder = permutationBuilder(table).shuffle({
|
||||
seed: 42,
|
||||
clumpSize: 2,
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with filter", async () => {
|
||||
const builder = permutationBuilder(table).filter("value > 50");
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100
|
||||
});
|
||||
|
||||
test("should chain multiple operations", async () => {
|
||||
const builder = permutationBuilder(table)
|
||||
.filter("value <= 80")
|
||||
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
|
||||
.shuffle({ seed: 123 });
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80
|
||||
|
||||
// Check split distribution
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(8);
|
||||
});
|
||||
|
||||
test("should throw error for invalid split arguments", () => {
|
||||
const builder = permutationBuilder(table);
|
||||
|
||||
// Test no arguments provided
|
||||
expect(() => builder.splitRandom({})).toThrow(
|
||||
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
);
|
||||
|
||||
// Test multiple arguments provided
|
||||
expect(() =>
|
||||
builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }),
|
||||
).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided");
|
||||
});
|
||||
|
||||
test("should throw error when builder is consumed", async () => {
|
||||
const builder = permutationBuilder(table);
|
||||
|
||||
// Execute once
|
||||
await builder.execute();
|
||||
|
||||
// Should throw error on second execution
|
||||
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
|
||||
});
|
||||
});
|
||||
111
nodejs/__test__/query.test.ts
Normal file
111
nodejs/__test__/query.test.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import * as tmp from "tmp";
|
||||
|
||||
import { type Table, connect } from "../lancedb";
|
||||
import {
|
||||
Field,
|
||||
FixedSizeList,
|
||||
Float32,
|
||||
Int64,
|
||||
Schema,
|
||||
Utf8,
|
||||
makeArrowTable,
|
||||
} from "../lancedb/arrow";
|
||||
import { Index } from "../lancedb/indices";
|
||||
|
||||
describe("Query outputSchema", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
let table: Table;
|
||||
|
||||
beforeEach(async () => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
const db = await connect(tmpDir.name);
|
||||
|
||||
// Create table with explicit schema to ensure proper types
|
||||
const schema = new Schema([
|
||||
new Field("a", new Int64(), true),
|
||||
new Field("text", new Utf8(), true),
|
||||
new Field(
|
||||
"vec",
|
||||
new FixedSizeList(2, new Field("item", new Float32())),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
|
||||
const data = makeArrowTable(
|
||||
[
|
||||
{ a: 1n, text: "foo", vec: [1, 2] },
|
||||
{ a: 2n, text: "bar", vec: [3, 4] },
|
||||
{ a: 3n, text: "baz", vec: [5, 6] },
|
||||
],
|
||||
{ schema },
|
||||
);
|
||||
table = await db.createTable("test", data);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
tmpDir.removeCallback();
|
||||
});
|
||||
|
||||
it("should return schema for plain query", async () => {
|
||||
const schema = await table.query().outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(3);
|
||||
expect(schema.fields.map((f) => f.name)).toEqual(["a", "text", "vec"]);
|
||||
expect(schema.fields[0].type.toString()).toBe("Int64");
|
||||
expect(schema.fields[1].type.toString()).toBe("Utf8");
|
||||
});
|
||||
|
||||
it("should return schema with dynamic projection", async () => {
|
||||
const schema = await table.query().select({ bl: "a * 2" }).outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(1);
|
||||
expect(schema.fields[0].name).toBe("bl");
|
||||
expect(schema.fields[0].type.toString()).toBe("Int64");
|
||||
});
|
||||
|
||||
it("should return schema for vector search with _distance column", async () => {
|
||||
const schema = await table
|
||||
.vectorSearch([1, 2])
|
||||
.select(["a"])
|
||||
.outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(2);
|
||||
expect(schema.fields.map((f) => f.name)).toEqual(["a", "_distance"]);
|
||||
expect(schema.fields[0].type.toString()).toBe("Int64");
|
||||
expect(schema.fields[1].type.toString()).toBe("Float32");
|
||||
});
|
||||
|
||||
it("should return schema for FTS search", async () => {
|
||||
await table.createIndex("text", { config: Index.fts() });
|
||||
|
||||
const schema = await table
|
||||
.search("foo", "fts")
|
||||
.select(["a"])
|
||||
.outputSchema();
|
||||
|
||||
// FTS search includes _score column in addition to selected columns
|
||||
expect(schema.fields.length).toBe(2);
|
||||
expect(schema.fields.map((f) => f.name)).toContain("a");
|
||||
expect(schema.fields.map((f) => f.name)).toContain("_score");
|
||||
const aField = schema.fields.find((f) => f.name === "a");
|
||||
expect(aField?.type.toString()).toBe("Int64");
|
||||
});
|
||||
|
||||
it("should return schema for take query", async () => {
|
||||
const schema = await table.takeOffsets([0]).select(["text"]).outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(1);
|
||||
expect(schema.fields[0].name).toBe("text");
|
||||
expect(schema.fields[0].type.toString()).toBe("Utf8");
|
||||
});
|
||||
|
||||
it("should return full schema when no select is specified", async () => {
|
||||
const schema = await table.query().outputSchema();
|
||||
|
||||
// Should return all columns
|
||||
expect(schema.fields.length).toBe(3);
|
||||
});
|
||||
});
|
||||
184
nodejs/__test__/sanitize.test.ts
Normal file
184
nodejs/__test__/sanitize.test.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import * as arrow from "../lancedb/arrow";
|
||||
import { sanitizeField, sanitizeType } from "../lancedb/sanitize";
|
||||
|
||||
describe("sanitize", function () {
|
||||
describe("sanitizeType function", function () {
|
||||
it("should handle type objects", function () {
|
||||
const type = new arrow.Int32();
|
||||
const result = sanitizeType(type);
|
||||
|
||||
expect(result.typeId).toBe(arrow.Type.Int);
|
||||
expect((result as arrow.Int).bitWidth).toBe(32);
|
||||
expect((result as arrow.Int).isSigned).toBe(true);
|
||||
|
||||
const floatType = {
|
||||
typeId: 3, // Type.Float = 3
|
||||
precision: 2,
|
||||
toString: () => "Float",
|
||||
isFloat: true,
|
||||
isFixedWidth: true,
|
||||
};
|
||||
|
||||
const floatResult = sanitizeType(floatType);
|
||||
expect(floatResult).toBeInstanceOf(arrow.DataType);
|
||||
expect(floatResult.typeId).toBe(arrow.Type.Float);
|
||||
|
||||
const floatResult2 = sanitizeType({ ...floatType, typeId: () => 3 });
|
||||
expect(floatResult2).toBeInstanceOf(arrow.DataType);
|
||||
expect(floatResult2.typeId).toBe(arrow.Type.Float);
|
||||
});
|
||||
|
||||
const allTypeNameTestCases = [
|
||||
["null", new arrow.Null()],
|
||||
["binary", new arrow.Binary()],
|
||||
["utf8", new arrow.Utf8()],
|
||||
["bool", new arrow.Bool()],
|
||||
["int8", new arrow.Int8()],
|
||||
["int16", new arrow.Int16()],
|
||||
["int32", new arrow.Int32()],
|
||||
["int64", new arrow.Int64()],
|
||||
["uint8", new arrow.Uint8()],
|
||||
["uint16", new arrow.Uint16()],
|
||||
["uint32", new arrow.Uint32()],
|
||||
["uint64", new arrow.Uint64()],
|
||||
["float16", new arrow.Float16()],
|
||||
["float32", new arrow.Float32()],
|
||||
["float64", new arrow.Float64()],
|
||||
["datemillisecond", new arrow.DateMillisecond()],
|
||||
["dateday", new arrow.DateDay()],
|
||||
["timenanosecond", new arrow.TimeNanosecond()],
|
||||
["timemicrosecond", new arrow.TimeMicrosecond()],
|
||||
["timemillisecond", new arrow.TimeMillisecond()],
|
||||
["timesecond", new arrow.TimeSecond()],
|
||||
["intervaldaytime", new arrow.IntervalDayTime()],
|
||||
["intervalyearmonth", new arrow.IntervalYearMonth()],
|
||||
["durationnanosecond", new arrow.DurationNanosecond()],
|
||||
["durationmicrosecond", new arrow.DurationMicrosecond()],
|
||||
["durationmillisecond", new arrow.DurationMillisecond()],
|
||||
["durationsecond", new arrow.DurationSecond()],
|
||||
] as const;
|
||||
|
||||
it.each(allTypeNameTestCases)(
|
||||
'should map type name "%s" to %s',
|
||||
function (name, expected) {
|
||||
const result = sanitizeType(name);
|
||||
expect(result).toBeInstanceOf(expected.constructor);
|
||||
},
|
||||
);
|
||||
|
||||
const caseVariationTestCases = [
|
||||
["NULL", new arrow.Null()],
|
||||
["Utf8", new arrow.Utf8()],
|
||||
["FLOAT32", new arrow.Float32()],
|
||||
["DaTedAy", new arrow.DateDay()],
|
||||
] as const;
|
||||
|
||||
it.each(caseVariationTestCases)(
|
||||
'should be case insensitive for type name "%s" mapped to %s',
|
||||
function (name, expected) {
|
||||
const result = sanitizeType(name);
|
||||
expect(result).toBeInstanceOf(expected.constructor);
|
||||
},
|
||||
);
|
||||
|
||||
it("should throw error for unrecognized type name", function () {
|
||||
expect(() => sanitizeType("invalid_type")).toThrow(
|
||||
"Unrecognized type name in schema: invalid_type",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("sanitizeField function", function () {
|
||||
it("should handle field with string type name", function () {
|
||||
const field = sanitizeField({
|
||||
name: "string_field",
|
||||
type: "utf8",
|
||||
nullable: true,
|
||||
metadata: new Map([["key", "value"]]),
|
||||
});
|
||||
|
||||
expect(field).toBeInstanceOf(arrow.Field);
|
||||
expect(field.name).toBe("string_field");
|
||||
expect(field.type).toBeInstanceOf(arrow.Utf8);
|
||||
expect(field.nullable).toBe(true);
|
||||
expect(field.metadata?.get("key")).toBe("value");
|
||||
});
|
||||
|
||||
it("should handle field with type object", function () {
|
||||
const floatType = {
|
||||
typeId: 3, // Float
|
||||
precision: 32,
|
||||
};
|
||||
|
||||
const field = sanitizeField({
|
||||
name: "float_field",
|
||||
type: floatType,
|
||||
nullable: false,
|
||||
});
|
||||
|
||||
expect(field).toBeInstanceOf(arrow.Field);
|
||||
expect(field.name).toBe("float_field");
|
||||
expect(field.type).toBeInstanceOf(arrow.DataType);
|
||||
expect(field.type.typeId).toBe(arrow.Type.Float);
|
||||
expect((field.type as arrow.Float64).precision).toBe(32);
|
||||
expect(field.nullable).toBe(false);
|
||||
});
|
||||
|
||||
it("should handle field with direct Type instance", function () {
|
||||
const field = sanitizeField({
|
||||
name: "bool_field",
|
||||
type: new arrow.Bool(),
|
||||
nullable: true,
|
||||
});
|
||||
|
||||
expect(field).toBeInstanceOf(arrow.Field);
|
||||
expect(field.name).toBe("bool_field");
|
||||
expect(field.type).toBeInstanceOf(arrow.Bool);
|
||||
expect(field.nullable).toBe(true);
|
||||
});
|
||||
|
||||
it("should throw error for invalid field object", function () {
|
||||
expect(() =>
|
||||
sanitizeField({
|
||||
type: "int32",
|
||||
nullable: true,
|
||||
}),
|
||||
).toThrow(
|
||||
"The field passed in is missing a `type`/`name`/`nullable` property",
|
||||
);
|
||||
|
||||
// Invalid type
|
||||
expect(() =>
|
||||
sanitizeField({
|
||||
name: "invalid",
|
||||
type: { invalid: true },
|
||||
nullable: true,
|
||||
}),
|
||||
).toThrow("Expected a Type to have a typeId property");
|
||||
|
||||
// Invalid nullable
|
||||
expect(() =>
|
||||
sanitizeField({
|
||||
name: "invalid_nullable",
|
||||
type: "int32",
|
||||
nullable: "not a boolean",
|
||||
}),
|
||||
).toThrow("The field passed in had a non-boolean `nullable` property");
|
||||
});
|
||||
|
||||
it("should report error for invalid type name", function () {
|
||||
expect(() =>
|
||||
sanitizeField({
|
||||
name: "invalid_field",
|
||||
type: "invalid_type",
|
||||
nullable: true,
|
||||
}),
|
||||
).toThrow(
|
||||
"Unable to sanitize type for field: invalid_field due to error: Error: Unrecognized type name in schema: invalid_type",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -10,7 +10,13 @@ import * as arrow16 from "apache-arrow-16";
|
||||
import * as arrow17 from "apache-arrow-17";
|
||||
import * as arrow18 from "apache-arrow-18";
|
||||
|
||||
import { MatchQuery, PhraseQuery, Table, connect } from "../lancedb";
|
||||
import {
|
||||
Connection,
|
||||
MatchQuery,
|
||||
PhraseQuery,
|
||||
Table,
|
||||
connect,
|
||||
} from "../lancedb";
|
||||
import {
|
||||
Table as ArrowTable,
|
||||
Field,
|
||||
@@ -21,6 +27,8 @@ import {
|
||||
Int64,
|
||||
List,
|
||||
Schema,
|
||||
SchemaLike,
|
||||
Type,
|
||||
Uint8,
|
||||
Utf8,
|
||||
makeArrowTable,
|
||||
@@ -853,6 +861,15 @@ describe("When creating an index", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should be able to create IVF_RQ", async () => {
|
||||
await tbl.createIndex("vec", {
|
||||
config: Index.ivfRq({
|
||||
numPartitions: 10,
|
||||
numBits: 1,
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow me to replace (or not) an existing index", async () => {
|
||||
await tbl.createIndex("id");
|
||||
// Default is replace=true
|
||||
@@ -2019,3 +2036,52 @@ describe("column name options", () => {
|
||||
expect(results2.length).toBe(10);
|
||||
});
|
||||
});
|
||||
|
||||
describe("when creating an empty table", () => {
|
||||
let con: Connection;
|
||||
beforeEach(async () => {
|
||||
const tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
con = await connect(tmpDir.name);
|
||||
});
|
||||
afterEach(() => {
|
||||
con.close();
|
||||
});
|
||||
|
||||
it("can create an empty table from an arrow Schema", async () => {
|
||||
const schema = new Schema([
|
||||
new Field("id", new Int64()),
|
||||
new Field("vector", new Float64()),
|
||||
]);
|
||||
const table = await con.createEmptyTable("test", schema);
|
||||
const actualSchema = await table.schema();
|
||||
expect(actualSchema.fields[0].type.typeId).toBe(Type.Int);
|
||||
expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64);
|
||||
expect(actualSchema.fields[1].type.typeId).toBe(Type.Float);
|
||||
expect((actualSchema.fields[1].type as Float64).precision).toBe(2);
|
||||
});
|
||||
|
||||
it("can create an empty table from schema that specifies field types by name", async () => {
|
||||
const schemaLike = {
|
||||
fields: [
|
||||
{
|
||||
name: "id",
|
||||
type: "int64",
|
||||
nullable: true,
|
||||
},
|
||||
{
|
||||
name: "vector",
|
||||
type: "float64",
|
||||
nullable: true,
|
||||
},
|
||||
],
|
||||
metadata: new Map(),
|
||||
names: ["id", "vector"],
|
||||
} satisfies SchemaLike;
|
||||
const table = await con.createEmptyTable("test", schemaLike);
|
||||
const actualSchema = await table.schema();
|
||||
expect(actualSchema.fields[0].type.typeId).toBe(Type.Int);
|
||||
expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64);
|
||||
expect(actualSchema.fields[1].type.typeId).toBe(Type.Float);
|
||||
expect((actualSchema.fields[1].type as Float64).precision).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -73,7 +73,7 @@ export type FieldLike =
|
||||
| {
|
||||
type: string;
|
||||
name: string;
|
||||
nullable?: boolean;
|
||||
nullable: boolean;
|
||||
metadata?: Map<string, string>;
|
||||
};
|
||||
|
||||
|
||||
@@ -43,6 +43,10 @@ export {
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
UpdateResult,
|
||||
SplitRandomOptions,
|
||||
SplitHashOptions,
|
||||
SplitSequentialOptions,
|
||||
ShuffleOptions,
|
||||
} from "./native.js";
|
||||
|
||||
export {
|
||||
@@ -85,6 +89,7 @@ export {
|
||||
Index,
|
||||
IndexOptions,
|
||||
IvfPqOptions,
|
||||
IvfRqOptions,
|
||||
IvfFlatOptions,
|
||||
HnswPqOptions,
|
||||
HnswSqOptions,
|
||||
@@ -110,6 +115,7 @@ export {
|
||||
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
export { permutationBuilder, PermutationBuilder } from "./permutation";
|
||||
export * as rerankers from "./rerankers";
|
||||
export {
|
||||
SchemaLike,
|
||||
|
||||
@@ -112,6 +112,77 @@ export interface IvfPqOptions {
|
||||
sampleRate?: number;
|
||||
}
|
||||
|
||||
export interface IvfRqOptions {
|
||||
/**
|
||||
* The number of IVF partitions to create.
|
||||
*
|
||||
* This value should generally scale with the number of rows in the dataset.
|
||||
* By default the number of partitions is the square root of the number of
|
||||
* rows.
|
||||
*
|
||||
* If this value is too large then the first part of the search (picking the
|
||||
* right partition) will be slow. If this value is too small then the second
|
||||
* part of the search (searching within a partition) will be slow.
|
||||
*/
|
||||
numPartitions?: number;
|
||||
|
||||
/**
|
||||
* Number of bits per dimension for residual quantization.
|
||||
*
|
||||
* This value controls how much each residual component is compressed. The more
|
||||
* bits, the more accurate the index will be but the slower search. Typical values
|
||||
* are small integers; the default is 1 bit per dimension.
|
||||
*/
|
||||
numBits?: number;
|
||||
|
||||
/**
|
||||
* Distance type to use to build the index.
|
||||
*
|
||||
* Default value is "l2".
|
||||
*
|
||||
* This is used when training the index to calculate the IVF partitions
|
||||
* (vectors are grouped in partitions with similar vectors according to this
|
||||
* distance type) and during quantization.
|
||||
*
|
||||
* The distance type used to train an index MUST match the distance type used
|
||||
* to search the index. Failure to do so will yield inaccurate results.
|
||||
*
|
||||
* The following distance types are available:
|
||||
*
|
||||
* "l2" - Euclidean distance.
|
||||
* "cosine" - Cosine distance.
|
||||
* "dot" - Dot product.
|
||||
*/
|
||||
distanceType?: "l2" | "cosine" | "dot";
|
||||
|
||||
/**
|
||||
* Max iterations to train IVF kmeans.
|
||||
*
|
||||
* When training an IVF index we use kmeans to calculate the partitions. This parameter
|
||||
* controls how many iterations of kmeans to run.
|
||||
*
|
||||
* The default value is 50.
|
||||
*/
|
||||
maxIterations?: number;
|
||||
|
||||
/**
|
||||
* The number of vectors, per partition, to sample when training IVF kmeans.
|
||||
*
|
||||
* When an IVF index is trained, we need to calculate partitions. These are groups
|
||||
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
|
||||
*
|
||||
* Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
|
||||
* random sample of the data. This parameter controls the size of the sample. The total
|
||||
* number of vectors used to train the index is `sample_rate * num_partitions`.
|
||||
*
|
||||
* Increasing this value might improve the quality of the index but in most cases the
|
||||
* default should be sufficient.
|
||||
*
|
||||
* The default value is 256.
|
||||
*/
|
||||
sampleRate?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Options to create an `HNSW_PQ` index
|
||||
*/
|
||||
@@ -523,6 +594,35 @@ export class Index {
|
||||
options?.distanceType,
|
||||
options?.numPartitions,
|
||||
options?.numSubVectors,
|
||||
options?.numBits,
|
||||
options?.maxIterations,
|
||||
options?.sampleRate,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an IvfRq index
|
||||
*
|
||||
* IVF-RQ (RabitQ Quantization) compresses vectors using RabitQ quantization
|
||||
* and organizes them into IVF partitions.
|
||||
*
|
||||
* The compression scheme is called RabitQ quantization. Each dimension is quantized into a small number of bits.
|
||||
* The parameters `num_bits` and `num_partitions` control this process, providing a tradeoff
|
||||
* between index size (and thus search speed) and index accuracy.
|
||||
*
|
||||
* The partitioning process is called IVF and the `num_partitions` parameter controls how
|
||||
* many groups to create.
|
||||
*
|
||||
* Note that training an IVF RQ index on a large dataset is a slow operation and
|
||||
* currently is also a memory intensive operation.
|
||||
*/
|
||||
static ivfRq(options?: Partial<IvfRqOptions>) {
|
||||
return new Index(
|
||||
LanceDbIndex.ivfRq(
|
||||
options?.distanceType,
|
||||
options?.numPartitions,
|
||||
options?.numBits,
|
||||
options?.maxIterations,
|
||||
options?.sampleRate,
|
||||
),
|
||||
|
||||
183
nodejs/lancedb/permutation.ts
Normal file
183
nodejs/lancedb/permutation.ts
Normal file
@@ -0,0 +1,183 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import {
|
||||
PermutationBuilder as NativePermutationBuilder,
|
||||
Table as NativeTable,
|
||||
ShuffleOptions,
|
||||
SplitHashOptions,
|
||||
SplitRandomOptions,
|
||||
SplitSequentialOptions,
|
||||
permutationBuilder as nativePermutationBuilder,
|
||||
} from "./native.js";
|
||||
import { LocalTable, Table } from "./table";
|
||||
|
||||
/**
|
||||
* A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
|
||||
*
|
||||
* This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
|
||||
* offering methods to configure data splits, shuffling, and filtering before executing
|
||||
* the permutation to create a new table.
|
||||
*/
|
||||
export class PermutationBuilder {
|
||||
private inner: NativePermutationBuilder;
|
||||
|
||||
/**
|
||||
* @hidden
|
||||
*/
|
||||
constructor(inner: NativePermutationBuilder) {
|
||||
this.inner = inner;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure random splits for the permutation.
|
||||
*
|
||||
* @param options - Configuration for random splitting
|
||||
* @returns A new PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
* // Split by ratios
|
||||
* builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
|
||||
*
|
||||
* // Split by counts
|
||||
* builder.splitRandom({ counts: [1000, 500], seed: 42 });
|
||||
*
|
||||
* // Split with fixed size
|
||||
* builder.splitRandom({ fixed: 100, seed: 42 });
|
||||
* ```
|
||||
*/
|
||||
splitRandom(options: SplitRandomOptions): PermutationBuilder {
|
||||
const newInner = this.inner.splitRandom(options);
|
||||
return new PermutationBuilder(newInner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure hash-based splits for the permutation.
|
||||
*
|
||||
* @param options - Configuration for hash-based splitting
|
||||
* @returns A new PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
* builder.splitHash({
|
||||
* columns: ["user_id"],
|
||||
* splitWeights: [70, 30],
|
||||
* discardWeight: 0
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
splitHash(options: SplitHashOptions): PermutationBuilder {
|
||||
const newInner = this.inner.splitHash(options);
|
||||
return new PermutationBuilder(newInner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure sequential splits for the permutation.
|
||||
*
|
||||
* @param options - Configuration for sequential splitting
|
||||
* @returns A new PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
* // Split by ratios
|
||||
* builder.splitSequential({ ratios: [0.8, 0.2] });
|
||||
*
|
||||
* // Split by counts
|
||||
* builder.splitSequential({ counts: [800, 200] });
|
||||
*
|
||||
* // Split with fixed size
|
||||
* builder.splitSequential({ fixed: 1000 });
|
||||
* ```
|
||||
*/
|
||||
splitSequential(options: SplitSequentialOptions): PermutationBuilder {
|
||||
const newInner = this.inner.splitSequential(options);
|
||||
return new PermutationBuilder(newInner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure calculated splits for the permutation.
|
||||
*
|
||||
* @param calculation - SQL expression for calculating splits
|
||||
* @returns A new PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
* builder.splitCalculated("user_id % 3");
|
||||
* ```
|
||||
*/
|
||||
splitCalculated(calculation: string): PermutationBuilder {
|
||||
const newInner = this.inner.splitCalculated(calculation);
|
||||
return new PermutationBuilder(newInner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure shuffling for the permutation.
|
||||
*
|
||||
* @param options - Configuration for shuffling
|
||||
* @returns A new PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
* // Basic shuffle
|
||||
* builder.shuffle({ seed: 42 });
|
||||
*
|
||||
* // Shuffle with clump size
|
||||
* builder.shuffle({ seed: 42, clumpSize: 10 });
|
||||
* ```
|
||||
*/
|
||||
shuffle(options: ShuffleOptions): PermutationBuilder {
|
||||
const newInner = this.inner.shuffle(options);
|
||||
return new PermutationBuilder(newInner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure filtering for the permutation.
|
||||
*
|
||||
* @param filter - SQL filter expression
|
||||
* @returns A new PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
* builder.filter("age > 18 AND status = 'active'");
|
||||
* ```
|
||||
*/
|
||||
filter(filter: string): PermutationBuilder {
|
||||
const newInner = this.inner.filter(filter);
|
||||
return new PermutationBuilder(newInner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the permutation and create the destination table.
|
||||
*
|
||||
* @returns A Promise that resolves to the new Table instance
|
||||
* @example
|
||||
* ```ts
|
||||
* const permutationTable = await builder.execute();
|
||||
* console.log(`Created table: ${permutationTable.name}`);
|
||||
* ```
|
||||
*/
|
||||
async execute(): Promise<Table> {
|
||||
const nativeTable: NativeTable = await this.inner.execute();
|
||||
return new LocalTable(nativeTable);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a permutation builder for the given table.
|
||||
*
|
||||
* @param table - The source table to create a permutation from
|
||||
* @returns A PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
* const builder = permutationBuilder(sourceTable, "training_data")
|
||||
* .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
|
||||
* .shuffle({ seed: 123 });
|
||||
*
|
||||
* const trainingTable = await builder.execute();
|
||||
* ```
|
||||
*/
|
||||
export function permutationBuilder(table: Table): PermutationBuilder {
|
||||
// Extract the inner native table from the TypeScript wrapper
|
||||
const localTable = table as LocalTable;
|
||||
// Access inner through type assertion since it's private
|
||||
const nativeBuilder = nativePermutationBuilder(
|
||||
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
|
||||
(localTable as any).inner,
|
||||
);
|
||||
return new PermutationBuilder(nativeBuilder);
|
||||
}
|
||||
@@ -326,6 +326,25 @@ export class QueryBase<
|
||||
return this.inner.analyzePlan();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the schema of the output that will be returned by this query.
|
||||
*
|
||||
* This can be used to inspect the types and names of the columns that will be
|
||||
* returned by the query before executing it.
|
||||
*
|
||||
* @returns An Arrow Schema describing the output columns.
|
||||
*/
|
||||
async outputSchema(): Promise<import("./arrow").Schema> {
|
||||
let schemaBuffer: Buffer;
|
||||
if (this.inner instanceof Promise) {
|
||||
schemaBuffer = await this.inner.then((inner) => inner.outputSchema());
|
||||
} else {
|
||||
schemaBuffer = await this.inner.outputSchema();
|
||||
}
|
||||
const schema = tableFromIPC(schemaBuffer).schema;
|
||||
return schema;
|
||||
}
|
||||
}
|
||||
|
||||
export class StandardQueryBase<
|
||||
|
||||
@@ -326,6 +326,9 @@ export function sanitizeDictionary(typeLike: object) {
|
||||
|
||||
// biome-ignore lint/suspicious/noExplicitAny: skip
|
||||
export function sanitizeType(typeLike: unknown): DataType<any> {
|
||||
if (typeof typeLike === "string") {
|
||||
return dataTypeFromName(typeLike);
|
||||
}
|
||||
if (typeof typeLike !== "object" || typeLike === null) {
|
||||
throw Error("Expected a Type but object was null/undefined");
|
||||
}
|
||||
@@ -447,7 +450,7 @@ export function sanitizeType(typeLike: unknown): DataType<any> {
|
||||
case Type.DurationSecond:
|
||||
return new DurationSecond();
|
||||
default:
|
||||
throw new Error("Unrecoginized type id in schema: " + typeId);
|
||||
throw new Error("Unrecognized type id in schema: " + typeId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -467,7 +470,15 @@ export function sanitizeField(fieldLike: unknown): Field {
|
||||
"The field passed in is missing a `type`/`name`/`nullable` property",
|
||||
);
|
||||
}
|
||||
const type = sanitizeType(fieldLike.type);
|
||||
let type: DataType;
|
||||
try {
|
||||
type = sanitizeType(fieldLike.type);
|
||||
} catch (error: unknown) {
|
||||
throw Error(
|
||||
`Unable to sanitize type for field: ${fieldLike.name} due to error: ${error}`,
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
const name = fieldLike.name;
|
||||
if (!(typeof name === "string")) {
|
||||
throw Error("The field passed in had a non-string `name` property");
|
||||
@@ -581,3 +592,46 @@ function sanitizeData(
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
const constructorsByTypeName = {
|
||||
null: () => new Null(),
|
||||
binary: () => new Binary(),
|
||||
utf8: () => new Utf8(),
|
||||
bool: () => new Bool(),
|
||||
int8: () => new Int8(),
|
||||
int16: () => new Int16(),
|
||||
int32: () => new Int32(),
|
||||
int64: () => new Int64(),
|
||||
uint8: () => new Uint8(),
|
||||
uint16: () => new Uint16(),
|
||||
uint32: () => new Uint32(),
|
||||
uint64: () => new Uint64(),
|
||||
float16: () => new Float16(),
|
||||
float32: () => new Float32(),
|
||||
float64: () => new Float64(),
|
||||
datemillisecond: () => new DateMillisecond(),
|
||||
dateday: () => new DateDay(),
|
||||
timenanosecond: () => new TimeNanosecond(),
|
||||
timemicrosecond: () => new TimeMicrosecond(),
|
||||
timemillisecond: () => new TimeMillisecond(),
|
||||
timesecond: () => new TimeSecond(),
|
||||
intervaldaytime: () => new IntervalDayTime(),
|
||||
intervalyearmonth: () => new IntervalYearMonth(),
|
||||
durationnanosecond: () => new DurationNanosecond(),
|
||||
durationmicrosecond: () => new DurationMicrosecond(),
|
||||
durationmillisecond: () => new DurationMillisecond(),
|
||||
durationsecond: () => new DurationSecond(),
|
||||
} as const;
|
||||
|
||||
type MappableTypeName = keyof typeof constructorsByTypeName;
|
||||
|
||||
export function dataTypeFromName(typeName: string): DataType {
|
||||
const normalizedTypeName = typeName.toLowerCase() as MappableTypeName;
|
||||
const _constructor = constructorsByTypeName[normalizedTypeName];
|
||||
|
||||
if (!_constructor) {
|
||||
throw new Error("Unrecognized type name in schema: " + typeName);
|
||||
}
|
||||
|
||||
return _constructor();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.22.2-beta.1",
|
||||
"version": "0.22.3-beta.0",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::sync::Mutex;
|
||||
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
|
||||
use lancedb::index::vector::{
|
||||
IvfFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder,
|
||||
IvfRqIndexBuilder,
|
||||
};
|
||||
use lancedb::index::Index as LanceDbIndex;
|
||||
use napi_derive::napi;
|
||||
@@ -65,6 +66,36 @@ impl Index {
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(factory)]
|
||||
pub fn ivf_rq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
num_bits: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
) -> napi::Result<Self> {
|
||||
let mut ivf_rq_builder = IvfRqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
ivf_rq_builder = ivf_rq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_bits) = num_bits {
|
||||
ivf_rq_builder = ivf_rq_builder.num_bits(num_bits);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
ivf_rq_builder = ivf_rq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
ivf_rq_builder = ivf_rq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfRq(ivf_rq_builder))),
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(factory)]
|
||||
pub fn ivf_flat(
|
||||
distance_type: Option<String>,
|
||||
|
||||
@@ -12,6 +12,7 @@ mod header;
|
||||
mod index;
|
||||
mod iterator;
|
||||
pub mod merge;
|
||||
pub mod permutation;
|
||||
mod query;
|
||||
pub mod remote;
|
||||
mod rerankers;
|
||||
|
||||
214
nodejs/src/permutation.rs
Normal file
214
nodejs/src/permutation.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{error::NapiErrorExt, table::Table};
|
||||
use lancedb::dataloader::{
|
||||
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
permutation::split::{SplitSizes, SplitStrategy},
|
||||
};
|
||||
use napi_derive::napi;
|
||||
|
||||
#[napi(object)]
|
||||
pub struct SplitRandomOptions {
|
||||
pub ratios: Option<Vec<f64>>,
|
||||
pub counts: Option<Vec<i64>>,
|
||||
pub fixed: Option<i64>,
|
||||
pub seed: Option<i64>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct SplitHashOptions {
|
||||
pub columns: Vec<String>,
|
||||
pub split_weights: Vec<i64>,
|
||||
pub discard_weight: Option<i64>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct SplitSequentialOptions {
|
||||
pub ratios: Option<Vec<f64>>,
|
||||
pub counts: Option<Vec<i64>>,
|
||||
pub fixed: Option<i64>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct ShuffleOptions {
|
||||
pub seed: Option<i64>,
|
||||
pub clump_size: Option<i64>,
|
||||
}
|
||||
|
||||
pub struct PermutationBuilderState {
|
||||
pub builder: Option<LancePermutationBuilder>,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct PermutationBuilder {
|
||||
state: Arc<Mutex<PermutationBuilderState>>,
|
||||
}
|
||||
|
||||
impl PermutationBuilder {
|
||||
pub fn new(builder: LancePermutationBuilder) -> Self {
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(PermutationBuilderState {
|
||||
builder: Some(builder),
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PermutationBuilder {
|
||||
fn modify(
|
||||
&self,
|
||||
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
|
||||
) -> napi::Result<Self> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let builder = state
|
||||
.builder
|
||||
.take()
|
||||
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
|
||||
state.builder = Some(func(builder));
|
||||
Ok(Self {
|
||||
state: self.state.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl PermutationBuilder {
|
||||
/// Configure random splits
|
||||
#[napi]
|
||||
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [
|
||||
options.ratios.is_some(),
|
||||
options.counts.is_some(),
|
||||
options.fixed.is_some(),
|
||||
]
|
||||
.iter()
|
||||
.filter(|&&x| x)
|
||||
.count();
|
||||
|
||||
if split_args_count != 1 {
|
||||
return Err(napi::Error::from_reason(
|
||||
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
));
|
||||
}
|
||||
|
||||
let sizes = if let Some(ratios) = options.ratios {
|
||||
SplitSizes::Percentages(ratios)
|
||||
} else if let Some(counts) = options.counts {
|
||||
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
|
||||
} else if let Some(fixed) = options.fixed {
|
||||
SplitSizes::Fixed(fixed as u64)
|
||||
} else {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
let seed = options.seed.map(|s| s as u64);
|
||||
|
||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
||||
}
|
||||
|
||||
/// Configure hash-based splits
|
||||
#[napi]
|
||||
pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<Self> {
|
||||
let split_weights = options
|
||||
.split_weights
|
||||
.into_iter()
|
||||
.map(|w| w as u64)
|
||||
.collect();
|
||||
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
|
||||
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Hash {
|
||||
columns: options.columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure sequential splits
|
||||
#[napi]
|
||||
pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [
|
||||
options.ratios.is_some(),
|
||||
options.counts.is_some(),
|
||||
options.fixed.is_some(),
|
||||
]
|
||||
.iter()
|
||||
.filter(|&&x| x)
|
||||
.count();
|
||||
|
||||
if split_args_count != 1 {
|
||||
return Err(napi::Error::from_reason(
|
||||
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
));
|
||||
}
|
||||
|
||||
let sizes = if let Some(ratios) = options.ratios {
|
||||
SplitSizes::Percentages(ratios)
|
||||
} else if let Some(counts) = options.counts {
|
||||
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
|
||||
} else if let Some(fixed) = options.fixed {
|
||||
SplitSizes::Fixed(fixed as u64)
|
||||
} else {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
||||
}
|
||||
|
||||
/// Configure calculated splits
|
||||
#[napi]
|
||||
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure shuffling
|
||||
#[napi]
|
||||
pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<Self> {
|
||||
let seed = options.seed.map(|s| s as u64);
|
||||
let clump_size = options.clump_size.map(|c| c as u64);
|
||||
|
||||
self.modify(|builder| {
|
||||
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure filtering
|
||||
#[napi]
|
||||
pub fn filter(&self, filter: String) -> napi::Result<Self> {
|
||||
self.modify(|builder| builder.with_filter(filter))
|
||||
}
|
||||
|
||||
/// Execute the permutation builder and create the table
|
||||
#[napi]
|
||||
pub async fn execute(&self) -> napi::Result<Table> {
|
||||
let builder = {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
state
|
||||
.builder
|
||||
.take()
|
||||
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?
|
||||
};
|
||||
|
||||
let table = builder.build().await.default_error()?;
|
||||
Ok(Table::new(table))
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permutation builder for the given table
|
||||
#[napi]
|
||||
pub fn permutation_builder(table: &crate::table::Table) -> napi::Result<PermutationBuilder> {
|
||||
use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder;
|
||||
|
||||
let inner_table = table.inner_ref()?.clone();
|
||||
let inner_builder = LancePermutationBuilder::new(inner_table);
|
||||
|
||||
Ok(PermutationBuilder::new(inner_builder))
|
||||
}
|
||||
@@ -22,7 +22,7 @@ use crate::error::NapiErrorExt;
|
||||
use crate::iterator::RecordBatchIterator;
|
||||
use crate::rerankers::Reranker;
|
||||
use crate::rerankers::RerankerCallbacks;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::util::{parse_distance_type, schema_to_buffer};
|
||||
|
||||
#[napi]
|
||||
pub struct Query {
|
||||
@@ -88,6 +88,12 @@ impl Query {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn output_schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner.output_schema().await.default_error()?;
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
@@ -273,6 +279,12 @@ impl VectorQuery {
|
||||
.rerank(Arc::new(Reranker::new(callbacks)));
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn output_schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner.output_schema().await.default_error()?;
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
@@ -346,6 +358,12 @@ impl TakeQuery {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn output_schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner.output_schema().await.default_error()?;
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use lancedb::ipc::ipc_file_to_batches;
|
||||
use lancedb::table::{
|
||||
AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, NewColumnTransform,
|
||||
@@ -16,6 +15,7 @@ use crate::error::NapiErrorExt;
|
||||
use crate::index::Index;
|
||||
use crate::merge::NativeMergeInsertBuilder;
|
||||
use crate::query::{Query, TakeQuery, VectorQuery};
|
||||
use crate::util::schema_to_buffer;
|
||||
|
||||
#[napi]
|
||||
pub struct Table {
|
||||
@@ -26,7 +26,7 @@ pub struct Table {
|
||||
}
|
||||
|
||||
impl Table {
|
||||
fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
|
||||
pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
|
||||
self.inner
|
||||
.as_ref()
|
||||
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))
|
||||
@@ -64,14 +64,7 @@ impl Table {
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner_ref()?.schema().await.default_error()?;
|
||||
let mut writer = FileWriter::try_new(vec![], &schema)
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
|
||||
writer
|
||||
.finish()
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
|
||||
Ok(Buffer::from(writer.into_inner().map_err(|e| {
|
||||
napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
|
||||
})?))
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use arrow_schema::Schema;
|
||||
use lancedb::DistanceType;
|
||||
use napi::bindgen_prelude::Buffer;
|
||||
|
||||
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
|
||||
match distance_type.as_ref().to_lowercase().as_str() {
|
||||
@@ -15,3 +18,15 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert an Arrow Schema to an Arrow IPC file buffer
|
||||
pub fn schema_to_buffer(schema: &Schema) -> napi::Result<Buffer> {
|
||||
let mut writer = FileWriter::try_new(vec![], schema)
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
|
||||
writer
|
||||
.finish()
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
|
||||
Ok(Buffer::from(writer.into_inner().map_err(|e| {
|
||||
napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
|
||||
})?))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.25.2-beta.2"
|
||||
current_version = "0.25.3-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
@@ -24,6 +24,19 @@ commit = true
|
||||
message = "Bump version: {current_version} → {new_version}"
|
||||
commit_args = ""
|
||||
|
||||
# Update Cargo.lock after version bump
|
||||
pre_commit_hooks = [
|
||||
"""
|
||||
cd python && cargo update -p lancedb-python
|
||||
if git diff --quiet ../Cargo.lock; then
|
||||
echo "Cargo.lock unchanged"
|
||||
else
|
||||
git add ../Cargo.lock
|
||||
echo "Updated and staged Cargo.lock"
|
||||
fi
|
||||
""",
|
||||
]
|
||||
|
||||
[tool.bumpversion.parts.pre_l]
|
||||
values = ["beta", "final"]
|
||||
optional_value = "final"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.25.2-beta.2"
|
||||
version = "0.25.3-beta.1"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
@@ -14,12 +14,12 @@ name = "_lancedb"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
arrow = { version = "55.1", features = ["pyarrow"] }
|
||||
arrow = { version = "56.2", features = ["pyarrow"] }
|
||||
async-trait = "0.1"
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
env_logger.workspace = true
|
||||
pyo3 = { version = "0.24", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.24", features = [
|
||||
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.25", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
] }
|
||||
@@ -28,7 +28,7 @@ futures.workspace = true
|
||||
tokio = { version = "1.40", features = ["sync"] }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = { version = "0.24", features = [
|
||||
pyo3-build-config = { version = "0.25", features = [
|
||||
"extension-module",
|
||||
"abi3-py39",
|
||||
] }
|
||||
|
||||
@@ -5,7 +5,7 @@ dynamic = ["version"]
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"numpy",
|
||||
"overrides>=0.7",
|
||||
"overrides>=0.7; python_version<'3.12'",
|
||||
"packaging",
|
||||
"pyarrow>=16",
|
||||
"pydantic>=1.10",
|
||||
|
||||
@@ -123,6 +123,8 @@ class Table:
|
||||
@property
|
||||
def tags(self) -> Tags: ...
|
||||
def query(self) -> Query: ...
|
||||
def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
|
||||
def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
|
||||
def vector_search(self) -> VectorQuery: ...
|
||||
|
||||
class Tags:
|
||||
@@ -133,6 +135,7 @@ class Tags:
|
||||
async def update(self, tag: str, version: int): ...
|
||||
|
||||
class IndexConfig:
|
||||
name: str
|
||||
index_type: str
|
||||
columns: List[str]
|
||||
|
||||
@@ -164,6 +167,7 @@ class Query:
|
||||
def postfilter(self): ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
|
||||
def nearest_to_text(self, query: dict) -> FTSQuery: ...
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(
|
||||
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
|
||||
) -> RecordBatchStream: ...
|
||||
@@ -171,6 +175,13 @@ class Query:
|
||||
async def analyze_plan(self) -> str: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class TakeQuery:
|
||||
def select(self, columns: List[str]): ...
|
||||
def with_row_id(self): ...
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(self) -> RecordBatchStream: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class FTSQuery:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
@@ -182,12 +193,14 @@ class FTSQuery:
|
||||
def get_query(self) -> str: ...
|
||||
def add_query_vector(self, query_vec: pa.Array) -> None: ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(
|
||||
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
|
||||
) -> RecordBatchStream: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class VectorQuery:
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(self) -> RecordBatchStream: ...
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
@@ -295,3 +308,34 @@ class AlterColumnsResult:
|
||||
|
||||
class DropColumnsResult:
|
||||
version: int
|
||||
|
||||
class AsyncPermutationBuilder:
|
||||
def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
|
||||
def split_random(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[List[float]] = None,
|
||||
counts: Optional[List[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def split_hash(
|
||||
self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def split_sequential(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[List[float]] = None,
|
||||
counts: Optional[List[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
|
||||
def shuffle(
|
||||
self, seed: Optional[int], clump_size: Optional[int]
|
||||
) -> "AsyncPermutationBuilder": ...
|
||||
def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
|
||||
async def execute(self) -> Table: ...
|
||||
|
||||
def async_permutation_builder(
|
||||
table: Table, dest_table_name: str
|
||||
) -> AsyncPermutationBuilder: ...
|
||||
|
||||
@@ -5,11 +5,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
|
||||
class EnforceOverrides:
|
||||
pass
|
||||
else:
|
||||
from overrides import EnforceOverrides, override # type: ignore
|
||||
|
||||
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
from overrides import EnforceOverrides, override # type: ignore
|
||||
|
||||
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
|
||||
from lancedb.background_loop import LOOP
|
||||
@@ -32,7 +41,6 @@ import deprecation
|
||||
if TYPE_CHECKING:
|
||||
import pyarrow as pa
|
||||
from .pydantic import LanceModel
|
||||
from datetime import timedelta
|
||||
|
||||
from ._lancedb import Connection as LanceDbConnection
|
||||
from .common import DATA, URI
|
||||
@@ -444,7 +452,12 @@ class LanceDBConnection(DBConnection):
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
_inner: Optional[LanceDbConnection] = None,
|
||||
):
|
||||
if _inner is not None:
|
||||
self._conn = _inner
|
||||
return
|
||||
|
||||
if not isinstance(uri, Path):
|
||||
scheme = get_uri_scheme(uri)
|
||||
is_local = isinstance(uri, Path) or scheme == "file"
|
||||
@@ -453,11 +466,6 @@ class LanceDBConnection(DBConnection):
|
||||
uri = Path(uri)
|
||||
uri = uri.expanduser().absolute()
|
||||
Path(uri).mkdir(parents=True, exist_ok=True)
|
||||
self._uri = str(uri)
|
||||
self._entered = False
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
self.storage_options = storage_options
|
||||
self.session = session
|
||||
|
||||
if read_consistency_interval is not None:
|
||||
read_consistency_interval_secs = read_consistency_interval.total_seconds()
|
||||
@@ -476,10 +484,32 @@ class LanceDBConnection(DBConnection):
|
||||
session,
|
||||
)
|
||||
|
||||
# TODO: It would be nice if we didn't store self.storage_options but it is
|
||||
# currently used by the LanceTable.to_lance method. This doesn't _really_
|
||||
# work because some paths like LanceDBConnection.from_inner will lose the
|
||||
# storage_options. Also, this class really shouldn't be holding any state
|
||||
# beyond _conn.
|
||||
self.storage_options = storage_options
|
||||
self._conn = AsyncConnection(LOOP.run(do_connect()))
|
||||
|
||||
@property
|
||||
def read_consistency_interval(self) -> Optional[timedelta]:
|
||||
return LOOP.run(self._conn.get_read_consistency_interval())
|
||||
|
||||
@property
|
||||
def session(self) -> Optional[Session]:
|
||||
return self._conn.session
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._conn.uri
|
||||
|
||||
@classmethod
|
||||
def from_inner(cls, inner: LanceDbConnection):
|
||||
return cls(None, _inner=inner)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
val = f"{self.__class__.__name__}(uri={self._uri!r}"
|
||||
val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
|
||||
if self.read_consistency_interval is not None:
|
||||
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
|
||||
val += ")"
|
||||
@@ -489,6 +519,10 @@ class LanceDBConnection(DBConnection):
|
||||
conn = AsyncConnection(await lancedb_connect(self.uri))
|
||||
return await conn.table_names(start_after=start_after, limit=limit)
|
||||
|
||||
@property
|
||||
def _inner(self) -> LanceDbConnection:
|
||||
return self._conn._inner
|
||||
|
||||
@override
|
||||
def list_namespaces(
|
||||
self,
|
||||
@@ -848,6 +882,13 @@ class AsyncConnection(object):
|
||||
def uri(self) -> str:
|
||||
return self._inner.uri
|
||||
|
||||
async def get_read_consistency_interval(self) -> Optional[timedelta]:
|
||||
interval_secs = await self._inner.get_read_consistency_interval()
|
||||
if interval_secs is not None:
|
||||
return timedelta(seconds=interval_secs)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def list_namespaces(
|
||||
self,
|
||||
namespace: List[str] = [],
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Union, Optional, Any
|
||||
from logging import warning
|
||||
from typing import List, Union, Optional, Any, Callable
|
||||
import numpy as np
|
||||
import io
|
||||
import warnings
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
@@ -19,35 +21,52 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
An embedding function that uses the ColPali engine for
|
||||
multimodal multi-vector embeddings.
|
||||
|
||||
This embedding function supports ColQwen2.5 models, producing multivector outputs
|
||||
for both text and image inputs. The output embeddings are lists of vectors, each
|
||||
vector being 128-dimensional by default, represented as List[List[float]].
|
||||
This embedding function supports ColPali models, producing multivector outputs
|
||||
for both text and image inputs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str
|
||||
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
|
||||
Supports models based on these engines:
|
||||
- ColPali: "vidore/colpali-v1.3" and others
|
||||
- ColQwen2.5: "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" and others
|
||||
- ColQwen2: "vidore/colqwen2-v1.0" and others
|
||||
- ColSmol: "vidore/colSmol-256M" and others
|
||||
|
||||
device : str
|
||||
The device for inference (default "cuda:0").
|
||||
The device for inference (default "auto").
|
||||
dtype : str
|
||||
Data type for model weights (default "bfloat16").
|
||||
use_token_pooling : bool
|
||||
Whether to use token pooling to reduce embedding size (default True).
|
||||
DEPRECATED. Whether to use token pooling. Use `pooling_strategy` instead.
|
||||
pooling_strategy : str, optional
|
||||
The token pooling strategy to use, by default "hierarchical".
|
||||
- "hierarchical": Progressively pools tokens to reduce sequence length.
|
||||
- "lambda": A simpler pooling that uses a custom `pooling_func`.
|
||||
pooling_func: typing.Callable, optional
|
||||
A function to use for pooling when `pooling_strategy` is "lambda".
|
||||
pool_factor : int
|
||||
Factor to reduce sequence length if token pooling is enabled (default 2).
|
||||
quantization_config : Optional[BitsAndBytesConfig]
|
||||
Quantization configuration for the model. (default None, bitsandbytes needed)
|
||||
batch_size : int
|
||||
Batch size for processing inputs (default 2).
|
||||
offload_folder: str, optional
|
||||
Folder to offload model weights if using CPU offloading (default None). This is
|
||||
useful for large models that do not fit in memory.
|
||||
"""
|
||||
|
||||
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
|
||||
device: str = "auto"
|
||||
dtype: str = "bfloat16"
|
||||
use_token_pooling: bool = True
|
||||
pooling_strategy: Optional[str] = "hierarchical"
|
||||
pooling_func: Optional[Any] = None
|
||||
pool_factor: int = 2
|
||||
quantization_config: Optional[Any] = None
|
||||
batch_size: int = 2
|
||||
offload_folder: Optional[str] = None
|
||||
|
||||
_model = None
|
||||
_processor = None
|
||||
@@ -56,15 +75,43 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
|
||||
if not self.use_token_pooling:
|
||||
warnings.warn(
|
||||
"use_token_pooling is deprecated, use pooling_strategy=None instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.pooling_strategy = None
|
||||
|
||||
if self.pooling_strategy == "lambda" and self.pooling_func is None:
|
||||
raise ValueError(
|
||||
"pooling_func must be provided when pooling_strategy is 'lambda'"
|
||||
)
|
||||
|
||||
device = self.device
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
dtype = self.dtype
|
||||
if device == "mps" and dtype == "bfloat16":
|
||||
dtype = "float32" # Avoid NaNs on MPS
|
||||
|
||||
(
|
||||
self._model,
|
||||
self._processor,
|
||||
self._token_pooler,
|
||||
) = self._load_model(
|
||||
self.model_name,
|
||||
self.dtype,
|
||||
self.device,
|
||||
self.use_token_pooling,
|
||||
dtype,
|
||||
device,
|
||||
self.pooling_strategy,
|
||||
self.pooling_func,
|
||||
self.quantization_config,
|
||||
)
|
||||
|
||||
@@ -74,16 +121,26 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
model_name: str,
|
||||
dtype: str,
|
||||
device: str,
|
||||
use_token_pooling: bool,
|
||||
pooling_strategy: Optional[str],
|
||||
pooling_func: Optional[Callable],
|
||||
quantization_config: Optional[Any],
|
||||
):
|
||||
"""
|
||||
Initialize and cache the ColPali model, processor, and token pooler.
|
||||
"""
|
||||
if device.startswith("mps"):
|
||||
# warn some torch ops in late interaction architecture result in nans on mps
|
||||
warning(
|
||||
"MPS device detected. Some operations may result in NaNs. "
|
||||
"If you encounter issues, consider using 'cpu' or 'cuda' devices."
|
||||
)
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
transformers = attempt_import_or_raise("transformers", "transformers")
|
||||
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
|
||||
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
|
||||
from colpali_engine.compression.token_pooling import (
|
||||
HierarchicalTokenPooler,
|
||||
LambdaTokenPooler,
|
||||
)
|
||||
|
||||
if quantization_config is not None:
|
||||
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
|
||||
@@ -98,21 +155,45 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
|
||||
model = colpali_engine.models.ColQwen2_5.from_pretrained(
|
||||
model_class, processor_class = None, None
|
||||
model_name_lower = model_name.lower()
|
||||
if "colqwen2.5" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColQwen2_5
|
||||
processor_class = colpali_engine.models.ColQwen2_5_Processor
|
||||
elif "colsmol" in model_name_lower or "colidefics3" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColIdefics3
|
||||
processor_class = colpali_engine.models.ColIdefics3Processor
|
||||
elif "colqwen" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColQwen2
|
||||
processor_class = colpali_engine.models.ColQwen2Processor
|
||||
elif "colpali" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColPali
|
||||
processor_class = colpali_engine.models.ColPaliProcessor
|
||||
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device,
|
||||
quantization_config=quantization_config
|
||||
if quantization_config is not None
|
||||
else None,
|
||||
attn_implementation="flash_attention_2"
|
||||
if is_flash_attn_2_available()
|
||||
else None,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
|
||||
model_name
|
||||
)
|
||||
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
|
||||
model = model.to(device)
|
||||
model = model.to(torch_dtype) # Force cast after moving to device
|
||||
processor = processor_class.from_pretrained(model_name)
|
||||
|
||||
token_pooler = None
|
||||
if pooling_strategy == "hierarchical":
|
||||
token_pooler = HierarchicalTokenPooler()
|
||||
elif pooling_strategy == "lambda":
|
||||
token_pooler = LambdaTokenPooler(pool_func=pooling_func)
|
||||
|
||||
return model, processor, token_pooler
|
||||
|
||||
def ndims(self):
|
||||
@@ -128,7 +209,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
with torch.no_grad():
|
||||
query_embeddings = self._model(**batch_queries)
|
||||
|
||||
if self.use_token_pooling and self._token_pooler is not None:
|
||||
if self.pooling_strategy and self._token_pooler is not None:
|
||||
query_embeddings = self._token_pooler.pool_embeddings(
|
||||
query_embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
@@ -145,13 +226,20 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
Use token pooling if enabled.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
if self.use_token_pooling and self._token_pooler is not None:
|
||||
embeddings = self._token_pooler.pool_embeddings(
|
||||
embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
if self.pooling_strategy and self._token_pooler is not None:
|
||||
if self.pooling_strategy == "hierarchical":
|
||||
embeddings = self._token_pooler.pool_embeddings(
|
||||
embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
elif self.pooling_strategy == "lambda":
|
||||
embeddings = self._token_pooler.pool_embeddings(
|
||||
embeddings,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
tensors = embeddings.detach().cpu()
|
||||
@@ -179,6 +267,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
)
|
||||
with torch.no_grad():
|
||||
query_embeddings = self._model(**batch_queries)
|
||||
query_embeddings = torch.nan_to_num(query_embeddings)
|
||||
all_embeddings.extend(self._process_embeddings(query_embeddings))
|
||||
return all_embeddings
|
||||
|
||||
@@ -225,6 +314,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
)
|
||||
with torch.no_grad():
|
||||
image_embeddings = self._model(**batch_images)
|
||||
image_embeddings = torch.nan_to_num(image_embeddings)
|
||||
all_embeddings.extend(self._process_embeddings(image_embeddings))
|
||||
return all_embeddings
|
||||
|
||||
|
||||
@@ -605,9 +605,53 @@ class IvfPq:
|
||||
target_partition_size: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class IvfRq:
|
||||
"""Describes an IVF RQ Index
|
||||
|
||||
IVF-RQ (Residual Quantization) stores a compressed copy of each vector using
|
||||
residual quantization and organizes them into IVF partitions. Parameters
|
||||
largely mirror IVF-PQ for consistency.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
distance_type: str, default "l2"
|
||||
Distance metric used to train the index and for quantization.
|
||||
|
||||
The following distance types are available:
|
||||
|
||||
"l2" - Euclidean distance.
|
||||
"cosine" - Cosine distance.
|
||||
"dot" - Dot product.
|
||||
|
||||
num_partitions: int, default sqrt(num_rows)
|
||||
Number of IVF partitions to create.
|
||||
|
||||
num_bits: int, default 1
|
||||
Number of bits to encode each dimension.
|
||||
|
||||
max_iterations: int, default 50
|
||||
Max iterations to train kmeans when computing IVF partitions.
|
||||
|
||||
sample_rate: int, default 256
|
||||
Controls the number of training vectors: sample_rate * num_partitions.
|
||||
|
||||
target_partition_size, default is 8192
|
||||
Target size of each partition.
|
||||
"""
|
||||
|
||||
distance_type: Literal["l2", "cosine", "dot"] = "l2"
|
||||
num_partitions: Optional[int] = None
|
||||
num_bits: int = 1
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BTree",
|
||||
"IvfPq",
|
||||
"IvfRq",
|
||||
"IvfFlat",
|
||||
"HnswPq",
|
||||
"HnswSq",
|
||||
|
||||
@@ -12,13 +12,18 @@ from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
import os
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from overrides import override
|
||||
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.table import LanceTable, Table
|
||||
from lancedb.util import validate_table_name
|
||||
from lancedb.common import validate_schema
|
||||
from lancedb.table import sanitize_create_table
|
||||
from overrides import override
|
||||
|
||||
from lance_namespace import LanceNamespace, connect as namespace_connect
|
||||
from lance_namespace_urllib3_client.models import (
|
||||
|
||||
72
python/python/lancedb/permutation.py
Normal file
72
python/python/lancedb/permutation.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from ._lancedb import async_permutation_builder
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class PermutationBuilder:
|
||||
def __init__(self, table: LanceTable):
|
||||
self._async = async_permutation_builder(table)
|
||||
|
||||
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
|
||||
self._async.select(projections)
|
||||
return self
|
||||
|
||||
def split_random(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[list[float]] = None,
|
||||
counts: Optional[list[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
|
||||
return self
|
||||
|
||||
def split_hash(
|
||||
self,
|
||||
columns: list[str],
|
||||
split_weights: list[int],
|
||||
*,
|
||||
discard_weight: Optional[int] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
|
||||
return self
|
||||
|
||||
def split_sequential(
|
||||
self,
|
||||
*,
|
||||
ratios: Optional[list[float]] = None,
|
||||
counts: Optional[list[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
|
||||
return self
|
||||
|
||||
def split_calculated(self, calculation: str) -> "PermutationBuilder":
|
||||
self._async.split_calculated(calculation)
|
||||
return self
|
||||
|
||||
def shuffle(
|
||||
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
|
||||
) -> "PermutationBuilder":
|
||||
self._async.shuffle(seed=seed, clump_size=clump_size)
|
||||
return self
|
||||
|
||||
def filter(self, filter: str) -> "PermutationBuilder":
|
||||
self._async.filter(filter)
|
||||
return self
|
||||
|
||||
def execute(self) -> LanceTable:
|
||||
async def do_execute():
|
||||
inner_tbl = await self._async.execute()
|
||||
return LanceTable.from_inner(inner_tbl)
|
||||
|
||||
return LOOP.run(do_execute())
|
||||
|
||||
|
||||
def permutation_builder(table: LanceTable) -> PermutationBuilder:
|
||||
return PermutationBuilder(table)
|
||||
@@ -1237,6 +1237,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return self._table._output_schema(self.to_query_object())
|
||||
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
@@ -1452,6 +1460,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
offset=self._offset,
|
||||
)
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return self._table._output_schema(self.to_query_object())
|
||||
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
if exist:
|
||||
@@ -1595,6 +1611,10 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
offset=self._offset,
|
||||
)
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
query = self.to_query_object()
|
||||
return self._table._output_schema(query)
|
||||
|
||||
def to_batches(
|
||||
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
@@ -2238,6 +2258,14 @@ class AsyncQueryBase(object):
|
||||
)
|
||||
)
|
||||
|
||||
async def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return await self._inner.output_schema()
|
||||
|
||||
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
Execute the query and collect the results into an Apache Arrow Table.
|
||||
@@ -3193,6 +3221,14 @@ class BaseQueryBuilder(object):
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return LOOP.run(self._inner.output_schema())
|
||||
|
||||
def to_batches(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -5,15 +5,20 @@
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import sys
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
import warnings
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from overrides import override
|
||||
|
||||
# Remove this import to fix circular dependency
|
||||
# from lancedb import connect_async
|
||||
from lancedb.remote import ClientConfig
|
||||
import pyarrow as pa
|
||||
from overrides import override
|
||||
|
||||
from ..common import DATA
|
||||
from ..db import DBConnection, LOOP
|
||||
|
||||
@@ -114,7 +114,7 @@ class RemoteTable(Table):
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
|
||||
*,
|
||||
replace: bool = False,
|
||||
wait_timeout: timedelta = None,
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Creates a scalar index
|
||||
@@ -153,7 +153,7 @@ class RemoteTable(Table):
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = False,
|
||||
wait_timeout: timedelta = None,
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
with_position: bool = False,
|
||||
# tokenizer configs:
|
||||
base_tokenizer: str = "simple",
|
||||
@@ -436,6 +436,9 @@ class RemoteTable(Table):
|
||||
def _analyze_plan(self, query: Query) -> str:
|
||||
return LOOP.run(self._table._analyze_plan(query))
|
||||
|
||||
def _output_schema(self, query: Query) -> pa.Schema:
|
||||
return LOOP.run(self._table._output_schema(query))
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||
that can be used to create a "merge insert" operation.
|
||||
|
||||
@@ -44,7 +44,7 @@ import numpy as np
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
|
||||
from .index import BTree, IvfFlat, IvfPq, Bitmap, IvfRq, LabelList, HnswPq, HnswSq, FTS
|
||||
from .merge import LanceMergeInsertBuilder
|
||||
from .pydantic import LanceModel, model_to_dict
|
||||
from .query import (
|
||||
@@ -74,6 +74,7 @@ from .index import lang_mapping
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import LanceDBConnection
|
||||
from ._lancedb import (
|
||||
Table as LanceDBTable,
|
||||
OptimizeStats,
|
||||
@@ -88,7 +89,6 @@ if TYPE_CHECKING:
|
||||
MergeResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from .db import LanceDBConnection
|
||||
from .index import IndexConfig
|
||||
import pandas
|
||||
import PIL
|
||||
@@ -1248,6 +1248,9 @@ class Table(ABC):
|
||||
@abstractmethod
|
||||
def _analyze_plan(self, query: Query) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def _output_schema(self, query: Query) -> pa.Schema: ...
|
||||
|
||||
@abstractmethod
|
||||
def _do_merge(
|
||||
self,
|
||||
@@ -1707,22 +1710,38 @@ class LanceTable(Table):
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
_async: AsyncTable = None,
|
||||
):
|
||||
self._conn = connection
|
||||
self._namespace = namespace
|
||||
self._table = LOOP.run(
|
||||
connection._conn.open_table(
|
||||
name,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
if _async is not None:
|
||||
self._table = _async
|
||||
else:
|
||||
self._table = LOOP.run(
|
||||
connection._conn.open_table(
|
||||
name,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._table.name
|
||||
|
||||
@classmethod
|
||||
def from_inner(cls, tbl: LanceDBTable):
|
||||
from .db import LanceDBConnection
|
||||
|
||||
async_tbl = AsyncTable(tbl)
|
||||
conn = LanceDBConnection.from_inner(tbl.database())
|
||||
return cls(
|
||||
conn,
|
||||
async_tbl.name,
|
||||
_async=async_tbl,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
|
||||
tbl = cls(db, name, namespace=namespace, **kwargs)
|
||||
@@ -1991,7 +2010,7 @@ class LanceTable(Table):
|
||||
index_cache_size: Optional[int] = None,
|
||||
num_bits: int = 8,
|
||||
index_type: Literal[
|
||||
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||
"IVF_FLAT", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||
] = "IVF_PQ",
|
||||
max_iterations: int = 50,
|
||||
sample_rate: int = 256,
|
||||
@@ -2039,6 +2058,15 @@ class LanceTable(Table):
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
)
|
||||
elif index_type == "IVF_RQ":
|
||||
config = IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
config = HnswPq(
|
||||
distance_type=metric,
|
||||
@@ -2736,6 +2764,9 @@ class LanceTable(Table):
|
||||
def _analyze_plan(self, query: Query) -> str:
|
||||
return LOOP.run(self._table._analyze_plan(query))
|
||||
|
||||
def _output_schema(self, query: Query) -> pa.Schema:
|
||||
return LOOP.run(self._table._output_schema(query))
|
||||
|
||||
def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
@@ -2747,6 +2778,10 @@ class LanceTable(Table):
|
||||
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
|
||||
)
|
||||
|
||||
@property
|
||||
def _inner(self) -> LanceDBTable:
|
||||
return self._table._inner
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.21.0",
|
||||
current_version=__version__,
|
||||
@@ -3330,7 +3365,7 @@ class AsyncTable:
|
||||
*,
|
||||
replace: Optional[bool] = None,
|
||||
config: Optional[
|
||||
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
||||
Union[IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
||||
] = None,
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
@@ -3369,11 +3404,12 @@ class AsyncTable:
|
||||
"""
|
||||
if config is not None:
|
||||
if not isinstance(
|
||||
config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
|
||||
config,
|
||||
(IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS),
|
||||
):
|
||||
raise TypeError(
|
||||
"config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"
|
||||
" Bitmap, LabelList, or FTS"
|
||||
"config must be an instance of IvfPq, IvfRq, HnswPq, HnswSq, BTree,"
|
||||
" Bitmap, LabelList, or FTS, but got " + str(type(config))
|
||||
)
|
||||
try:
|
||||
await self._inner.create_index(
|
||||
@@ -3888,6 +3924,10 @@ class AsyncTable:
|
||||
async_query = self._sync_query_to_async(query)
|
||||
return await async_query.analyze_plan()
|
||||
|
||||
async def _output_schema(self, query: Query) -> pa.Schema:
|
||||
async_query = self._sync_query_to_async(query)
|
||||
return await async_query.output_schema()
|
||||
|
||||
async def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
|
||||
@@ -18,10 +18,17 @@ AddMode = Literal["append", "overwrite"]
|
||||
CreateMode = Literal["create", "overwrite"]
|
||||
|
||||
# Index type literals
|
||||
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
|
||||
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", "IVF_RQ"]
|
||||
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
|
||||
IndexType = Literal[
|
||||
"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
|
||||
"IVF_PQ",
|
||||
"IVF_HNSW_PQ",
|
||||
"IVF_HNSW_SQ",
|
||||
"FTS",
|
||||
"BTREE",
|
||||
"BITMAP",
|
||||
"LABEL_LIST",
|
||||
"IVF_RQ",
|
||||
]
|
||||
|
||||
# Tokenizer literals
|
||||
|
||||
@@ -656,6 +656,106 @@ def test_colpali(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("colpali_engine") is None,
|
||||
reason="colpali_engine not installed",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"vidore/colSmol-256M",
|
||||
"vidore/colqwen2.5-v0.2",
|
||||
"vidore/colpali-v1.3",
|
||||
"vidore/colqwen2-v1.0",
|
||||
],
|
||||
)
|
||||
def test_colpali_models(tmp_path, model_name):
|
||||
import requests
|
||||
from lancedb.pydantic import LanceModel
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get("colpali").create(model_name=model_name)
|
||||
|
||||
class MediaItems(LanceModel):
|
||||
text: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
image_vectors: MultiVector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table(f"media_{model_name.replace('/', '_')}", schema=MediaItems)
|
||||
|
||||
texts = [
|
||||
"a cute cat playing with yarn",
|
||||
]
|
||||
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
]
|
||||
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
|
||||
table.add(
|
||||
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
|
||||
)
|
||||
|
||||
image_results = (
|
||||
table.search("fluffy companion", vector_column_name="image_vectors")
|
||||
.limit(1)
|
||||
.to_pydantic(MediaItems)[0]
|
||||
)
|
||||
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
|
||||
|
||||
first_row = table.to_arrow().to_pylist()[0]
|
||||
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
|
||||
assert len(first_row["image_vectors"][0]) == func.ndims(), (
|
||||
"Vector dimension mismatch"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("colpali_engine") is None,
|
||||
reason="colpali_engine not installed",
|
||||
)
|
||||
def test_colpali_pooling(tmp_path):
|
||||
registry = get_registry()
|
||||
model_name = "vidore/colSmol-256M"
|
||||
test_sentence = "a test sentence for pooling"
|
||||
|
||||
# 1. Get embeddings with no pooling
|
||||
func_no_pool = registry.get("colpali").create(
|
||||
model_name=model_name, pooling_strategy=None
|
||||
)
|
||||
unpooled_embeddings = func_no_pool.generate_text_embeddings([test_sentence])[0]
|
||||
original_length = len(unpooled_embeddings)
|
||||
assert original_length > 1
|
||||
|
||||
# 2. Test hierarchical pooling
|
||||
func_hierarchical = registry.get("colpali").create(
|
||||
model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
|
||||
)
|
||||
hierarchical_embeddings = func_hierarchical.generate_text_embeddings(
|
||||
[test_sentence]
|
||||
)[0]
|
||||
expected_hierarchical_length = (original_length + 1) // 2
|
||||
assert len(hierarchical_embeddings) == expected_hierarchical_length
|
||||
|
||||
# 3. Test lambda pooling
|
||||
def simple_pool_func(tensor):
|
||||
return tensor[::2]
|
||||
|
||||
func_lambda = registry.get("colpali").create(
|
||||
model_name=model_name,
|
||||
pooling_strategy="lambda",
|
||||
pooling_func=simple_pool_func,
|
||||
)
|
||||
lambda_embeddings = func_lambda.generate_text_embeddings([test_sentence])[0]
|
||||
expected_lambda_length = (original_length + 1) // 2
|
||||
assert len(lambda_embeddings) == expected_lambda_length
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_siglip(tmp_path, test_images, query_image_bytes):
|
||||
from PIL import Image
|
||||
|
||||
@@ -8,7 +8,17 @@ import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb import AsyncConnection, AsyncTable, connect_async
|
||||
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
|
||||
from lancedb.index import (
|
||||
BTree,
|
||||
IvfFlat,
|
||||
IvfPq,
|
||||
IvfRq,
|
||||
Bitmap,
|
||||
LabelList,
|
||||
HnswPq,
|
||||
HnswSq,
|
||||
FTS,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -195,6 +205,16 @@ async def test_create_4bit_ivfpq_index(some_table: AsyncTable):
|
||||
assert stats.loss >= 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_ivfrq_index(some_table: AsyncTable):
|
||||
await some_table.create_index("vector", config=IvfRq(num_bits=1))
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "IvfRq"
|
||||
assert indices[0].columns == ["vector"]
|
||||
assert indices[0].name == "vector_idx"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_hnswpq_index(some_table: AsyncTable):
|
||||
await some_table.create_index("vector", config=HnswPq(num_partitions=10))
|
||||
|
||||
462
python/python/tests/test_permutation.py
Normal file
462
python/python/tests/test_permutation.py
Normal file
@@ -0,0 +1,462 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.permutation import permutation_builder
|
||||
|
||||
|
||||
def test_split_random_ratios(mem_db):
|
||||
"""Test random splitting with ratios."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
|
||||
|
||||
# Check that the table was created and has data
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
# Check that split_id column exists and has correct values
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert set(split_ids) == {0, 1}
|
||||
|
||||
# Check approximate split sizes (allowing for rounding)
|
||||
split_0_count = split_ids.count(0)
|
||||
split_1_count = split_ids.count(1)
|
||||
assert 25 <= split_0_count <= 35 # ~30% ± tolerance
|
||||
assert 65 <= split_1_count <= 75 # ~70% ± tolerance
|
||||
|
||||
|
||||
def test_split_random_counts(mem_db):
|
||||
"""Test random splitting with absolute counts."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
|
||||
|
||||
# Check that we have exactly the requested counts
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert split_ids.count(0) == 20
|
||||
assert split_ids.count(1) == 30
|
||||
|
||||
|
||||
def test_split_random_fixed(mem_db):
|
||||
"""Test random splitting with fixed number of splits."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()
|
||||
|
||||
# Check that we have 4 splits with 25 rows each
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert set(split_ids) == {0, 1, 2, 3}
|
||||
|
||||
for split_id in range(4):
|
||||
assert split_ids.count(split_id) == 25
|
||||
|
||||
|
||||
def test_split_random_with_seed(mem_db):
|
||||
"""Test that seeded random splits are reproducible."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
|
||||
|
||||
# Create two identical permutations with same seed
|
||||
perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||
|
||||
perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||
|
||||
# Results should be identical
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
|
||||
assert data1["row_id"] == data2["row_id"]
|
||||
assert data1["split_id"] == data2["split_id"]
|
||||
|
||||
|
||||
def test_split_hash(mem_db):
|
||||
"""Test hash-based splitting."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C"] * 34)[:100], # Repeating pattern
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Should have all 100 rows (no discard)
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
split_ids = data["split_id"]
|
||||
assert set(split_ids) == {0, 1}
|
||||
|
||||
# Verify that each split has roughly 50 rows (allowing for hash variance)
|
||||
split_0_count = split_ids.count(0)
|
||||
split_1_count = split_ids.count(1)
|
||||
assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution
|
||||
assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution
|
||||
|
||||
# Hash splits should be deterministic - same category should go to same split
|
||||
# Let's verify by creating another permutation and checking consistency
|
||||
perm2 = (
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
assert data["split_id"] == data2["split_id"] # Should be identical
|
||||
|
||||
|
||||
def test_split_hash_with_discard(mem_db):
|
||||
"""Test hash-based splitting with discard weight."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Should have fewer than 100 rows due to discard
|
||||
row_count = permutation_tbl.count_rows()
|
||||
assert row_count < 100
|
||||
assert row_count > 0 # But not empty
|
||||
|
||||
|
||||
def test_split_sequential(mem_db):
|
||||
"""Test sequential splitting."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 70
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
split_ids = data["split_id"]
|
||||
|
||||
# Sequential should maintain order
|
||||
assert row_ids == sorted(row_ids)
|
||||
|
||||
# First 30 should be split 0, next 40 should be split 1
|
||||
assert split_ids[:30] == [0] * 30
|
||||
assert split_ids[30:] == [1] * 40
|
||||
|
||||
|
||||
def test_split_calculated(mem_db):
|
||||
"""Test calculated splitting."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.split_calculated("id % 3") # Split based on id modulo 3
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
split_ids = data["split_id"]
|
||||
|
||||
# Verify the calculation: each row's split_id should equal row_id % 3
|
||||
for i, (row_id, split_id) in enumerate(zip(row_ids, split_ids)):
|
||||
assert split_id == row_id % 3
|
||||
|
||||
|
||||
def test_split_error_cases(mem_db):
|
||||
"""Test error handling for invalid split parameters."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
|
||||
|
||||
# Test split_random with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl).split_random().execute()
|
||||
|
||||
# Test split_random with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl).split_random(
|
||||
ratios=[0.5, 0.5], counts=[5, 5]
|
||||
).execute()
|
||||
|
||||
# Test split_sequential with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl).split_sequential().execute()
|
||||
|
||||
# Test split_sequential with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
|
||||
|
||||
|
||||
def test_shuffle_no_seed(mem_db):
|
||||
"""Test shuffling without a seed."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||
)
|
||||
|
||||
# Create a permutation with shuffling (no seed)
|
||||
permutation_tbl = permutation_builder(tbl).shuffle().execute()
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
# Row IDs should not be in sequential order due to shuffling
|
||||
# This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
|
||||
# in order
|
||||
assert row_ids != list(range(100))
|
||||
|
||||
|
||||
def test_shuffle_with_seed(mem_db):
|
||||
"""Test that shuffling with a seed is reproducible."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||
)
|
||||
|
||||
# Create two identical permutations with same shuffle seed
|
||||
perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||
|
||||
perm2 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||
|
||||
# Results should be identical due to same seed
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
|
||||
assert data1["row_id"] == data2["row_id"]
|
||||
assert data1["split_id"] == data2["split_id"]
|
||||
|
||||
|
||||
def test_shuffle_with_clump_size(mem_db):
|
||||
"""Test shuffling with clump size."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||
)
|
||||
|
||||
# Create a permutation with shuffling using clumps
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.shuffle(clump_size=10) # 10-row clumps
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
for i in range(10):
|
||||
start = row_ids[i * 10]
|
||||
assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))
|
||||
|
||||
|
||||
def test_shuffle_different_seeds(mem_db):
|
||||
"""Test that different seeds produce different shuffle orders."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||
)
|
||||
|
||||
# Create two permutations with different shuffle seeds
|
||||
perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
|
||||
|
||||
perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()
|
||||
|
||||
# Results should be different due to different seeds
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||
|
||||
# Row order should be different
|
||||
assert data1["row_id"] != data2["row_id"]
|
||||
|
||||
|
||||
def test_shuffle_combined_with_splits(mem_db):
|
||||
"""Test shuffling combined with different split strategies."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C"] * 34)[:100],
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Test shuffle with random splits
|
||||
perm_random = (
|
||||
permutation_builder(tbl)
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.shuffle(seed=123, clump_size=None)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Test shuffle with hash splits
|
||||
perm_hash = (
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.shuffle(seed=456, clump_size=5)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Test shuffle with sequential splits
|
||||
perm_sequential = (
|
||||
permutation_builder(tbl)
|
||||
.split_sequential(counts=[40, 35])
|
||||
.shuffle(seed=789, clump_size=None)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Verify all permutations work and have expected properties
|
||||
assert perm_random.count_rows() == 100
|
||||
assert perm_hash.count_rows() == 100
|
||||
assert perm_sequential.count_rows() == 75
|
||||
|
||||
# Verify shuffle affected the order
|
||||
data_random = perm_random.search(None).to_arrow().to_pydict()
|
||||
data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
|
||||
|
||||
assert data_random["row_id"] != list(range(100))
|
||||
assert data_sequential["row_id"] != list(range(75))
|
||||
|
||||
|
||||
def test_no_shuffle_maintains_order(mem_db):
|
||||
"""Test that not calling shuffle maintains the original order."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||
)
|
||||
|
||||
# Create permutation without shuffle (should maintain some order)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.split_sequential(counts=[25, 25]) # Sequential maintains order
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
# With sequential splits and no shuffle, should maintain order
|
||||
assert row_ids == list(range(50))
|
||||
|
||||
|
||||
def test_filter_basic(mem_db):
|
||||
"""Test basic filtering functionality."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(100), "value": range(100, 200)})
|
||||
)
|
||||
|
||||
# Filter to only include rows where id < 50
|
||||
permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()
|
||||
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
# All row_ids should be less than 50
|
||||
assert all(row_id < 50 for row_id in row_ids)
|
||||
|
||||
|
||||
def test_filter_with_splits(mem_db):
|
||||
"""Test filtering combined with split strategies."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C"] * 34)[:100],
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Filter to only category A and B, then split
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.filter("category IN ('A', 'B')")
|
||||
.split_random(ratios=[0.5, 0.5])
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Should have fewer than 100 rows due to filtering
|
||||
row_count = permutation_tbl.count_rows()
|
||||
assert row_count == 67
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
categories = data["category"]
|
||||
|
||||
# All categories should be A or B
|
||||
assert all(cat in ["A", "B"] for cat in categories)
|
||||
|
||||
|
||||
def test_filter_with_shuffle(mem_db):
|
||||
"""Test filtering combined with shuffling."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table",
|
||||
pa.table(
|
||||
{
|
||||
"id": range(100),
|
||||
"category": (["A", "B", "C", "D"] * 25)[:100],
|
||||
"value": range(100),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Filter and shuffle
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.filter("category IN ('A', 'C')")
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
|
||||
row_count = permutation_tbl.count_rows()
|
||||
assert row_count == 50 # Should have 50 rows (A and C categories)
|
||||
|
||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||
row_ids = data["row_id"]
|
||||
|
||||
assert row_ids != sorted(row_ids)
|
||||
|
||||
|
||||
def test_filter_empty_result(mem_db):
|
||||
"""Test filtering that results in empty set."""
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||
)
|
||||
|
||||
# Filter that matches nothing
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl)
|
||||
.filter("value > 100") # No values > 100 in our data
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 0
|
||||
@@ -1298,6 +1298,79 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
)
|
||||
|
||||
|
||||
def test_query_schema(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table(
|
||||
"test",
|
||||
pa.table(
|
||||
{
|
||||
"a": [1, 2, 3],
|
||||
"text": ["a", "b", "c"],
|
||||
"vec": pa.array(
|
||||
[[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
assert tbl.search(None).output_schema() == pa.schema(
|
||||
{
|
||||
"a": pa.int64(),
|
||||
"text": pa.string(),
|
||||
"vec": pa.list_(pa.float32(), list_size=2),
|
||||
}
|
||||
)
|
||||
assert tbl.search(None).select({"bl": "a * 2"}).output_schema() == pa.schema(
|
||||
{"bl": pa.int64()}
|
||||
)
|
||||
assert tbl.search([1, 2]).select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64(), "_distance": pa.float32()}
|
||||
)
|
||||
assert tbl.search("blah").select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64()}
|
||||
)
|
||||
assert tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
|
||||
{"text": pa.string()}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_schema_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
tbl = await db.create_table(
|
||||
"test",
|
||||
pa.table(
|
||||
{
|
||||
"a": [1, 2, 3],
|
||||
"text": ["a", "b", "c"],
|
||||
"vec": pa.array(
|
||||
[[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
assert await tbl.query().output_schema() == pa.schema(
|
||||
{
|
||||
"a": pa.int64(),
|
||||
"text": pa.string(),
|
||||
"vec": pa.list_(pa.float32(), list_size=2),
|
||||
}
|
||||
)
|
||||
assert await tbl.query().select({"bl": "a * 2"}).output_schema() == pa.schema(
|
||||
{"bl": pa.int64()}
|
||||
)
|
||||
assert await tbl.vector_search([1, 2]).select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64(), "_distance": pa.float32()}
|
||||
)
|
||||
assert await (await tbl.search("blah")).select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64()}
|
||||
)
|
||||
assert await tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
|
||||
{"text": pa.string()}
|
||||
)
|
||||
|
||||
|
||||
def test_query_timeout(tmp_path):
|
||||
# Use local directory instead of memory:// to add a bit of latency to
|
||||
# operations so a timeout of zero will trigger exceptions.
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
|
||||
use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
database::{CreateTableMode, ReadConsistency},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
@@ -23,7 +26,7 @@ impl Connection {
|
||||
Self { inner: Some(inner) }
|
||||
}
|
||||
|
||||
fn get_inner(&self) -> PyResult<&LanceConnection> {
|
||||
pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
|
||||
self.inner
|
||||
.as_ref()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
|
||||
@@ -63,6 +66,18 @@ impl Connection {
|
||||
self.get_inner().map(|inner| inner.uri().to_string())
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
Ok(match inner.read_consistency().await.infer_error()? {
|
||||
ReadConsistency::Manual => None,
|
||||
ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
|
||||
ReadConsistency::Strong => Some(0.0_f64),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
|
||||
pub fn table_names(
|
||||
self_: PyRef<'_, Self>,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use lancedb::index::vector::IvfFlatIndexBuilder;
|
||||
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder};
|
||||
use lancedb::index::{
|
||||
scalar::{BTreeIndexBuilder, FtsIndexBuilder},
|
||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||
@@ -87,6 +87,22 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
}
|
||||
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
|
||||
},
|
||||
"IvfRq" => {
|
||||
let params = source.extract::<IvfRqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
let mut ivf_rq_builder = IvfRqIndexBuilder::default()
|
||||
.distance_type(distance_type)
|
||||
.max_iterations(params.max_iterations)
|
||||
.sample_rate(params.sample_rate)
|
||||
.num_bits(params.num_bits);
|
||||
if let Some(num_partitions) = params.num_partitions {
|
||||
ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(target_partition_size) = params.target_partition_size {
|
||||
ivf_rq_builder = ivf_rq_builder.target_partition_size(target_partition_size);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfRq(ivf_rq_builder))
|
||||
},
|
||||
"HnswPq" => {
|
||||
let params = source.extract::<IvfHnswPqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -170,6 +186,16 @@ struct IvfPqParams {
|
||||
target_partition_size: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfRqParams {
|
||||
distance_type: String,
|
||||
num_partitions: Option<u32>,
|
||||
num_bits: u32,
|
||||
max_iterations: u32,
|
||||
sample_rate: u32,
|
||||
target_partition_size: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfHnswPqParams {
|
||||
distance_type: String,
|
||||
|
||||
@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::IndexConfig;
|
||||
use permutation::PyAsyncPermutationBuilder;
|
||||
use pyo3::{
|
||||
pymodule,
|
||||
types::{PyModule, PyModuleMethods},
|
||||
@@ -22,6 +23,7 @@ pub mod connection;
|
||||
pub mod error;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod session;
|
||||
pub mod table;
|
||||
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<DeleteResult>()?;
|
||||
m.add_class::<DropColumnsResult>()?;
|
||||
m.add_class::<UpdateResult>()?;
|
||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
|
||||
170
python/src/permutation.rs
Normal file
170
python/src/permutation.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{error::PythonErrorExt, table::Table};
|
||||
use lancedb::dataloader::{
|
||||
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
permutation::split::{SplitSizes, SplitStrategy},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
|
||||
PyResult,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
/// Create a permutation builder for the given table
|
||||
#[pyo3::pyfunction]
|
||||
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
|
||||
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let inner_table = table.borrow().inner_ref()?.clone();
|
||||
let inner_builder = LancePermutationBuilder::new(inner_table);
|
||||
|
||||
Ok(PyAsyncPermutationBuilder {
|
||||
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
|
||||
builder: Some(inner_builder),
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
struct PyAsyncPermutationBuilderState {
|
||||
builder: Option<LancePermutationBuilder>,
|
||||
}
|
||||
|
||||
#[pyclass(name = "AsyncPermutationBuilder")]
|
||||
pub struct PyAsyncPermutationBuilder {
|
||||
state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
|
||||
}
|
||||
|
||||
impl PyAsyncPermutationBuilder {
|
||||
fn modify(
|
||||
&self,
|
||||
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
|
||||
) -> PyResult<Self> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let builder = state
|
||||
.builder
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
||||
state.builder = Some(func(builder));
|
||||
Ok(Self {
|
||||
state: self.state.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyAsyncPermutationBuilder {
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
|
||||
pub fn split_random(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
seed: Option<u64>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
.iter()
|
||||
.filter(|&&x| x)
|
||||
.count();
|
||||
|
||||
if split_args_count != 1 {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
));
|
||||
}
|
||||
|
||||
let sizes = if let Some(ratios) = ratios {
|
||||
SplitSizes::Percentages(ratios)
|
||||
} else if let Some(counts) = counts {
|
||||
SplitSizes::Counts(counts)
|
||||
} else if let Some(fixed) = fixed {
|
||||
SplitSizes::Fixed(fixed)
|
||||
} else {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
||||
}
|
||||
|
||||
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
|
||||
pub fn split_hash(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
columns: Vec<String>,
|
||||
split_weights: Vec<u64>,
|
||||
discard_weight: u64,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Hash {
|
||||
columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
|
||||
pub fn split_sequential(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
.iter()
|
||||
.filter(|&&x| x)
|
||||
.count();
|
||||
|
||||
if split_args_count != 1 {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
));
|
||||
}
|
||||
|
||||
let sizes = if let Some(ratios) = ratios {
|
||||
SplitSizes::Percentages(ratios)
|
||||
} else if let Some(counts) = counts {
|
||||
SplitSizes::Counts(counts)
|
||||
} else if let Some(fixed) = fixed {
|
||||
SplitSizes::Fixed(fixed)
|
||||
} else {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
||||
}
|
||||
|
||||
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
|
||||
}
|
||||
|
||||
pub fn shuffle(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
seed: Option<u64>,
|
||||
clump_size: Option<u64>,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
|
||||
})
|
||||
}
|
||||
|
||||
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
|
||||
slf.modify(|builder| builder.with_filter(filter))
|
||||
}
|
||||
|
||||
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let mut state = slf.state.lock().unwrap();
|
||||
let builder = state
|
||||
.builder
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
||||
|
||||
future_into_py(slf.py(), async move {
|
||||
let table = builder.build().await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use arrow::pyarrow::IntoPyArrow;
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
use lancedb::index::scalar::{
|
||||
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||
Operator, PhraseQuery,
|
||||
@@ -30,6 +31,7 @@ use pyo3::IntoPyObject;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::Python;
|
||||
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
|
||||
use pyo3::{
|
||||
exceptions::{PyNotImplementedError, PyValueError},
|
||||
@@ -445,6 +447,15 @@ impl Query {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -515,6 +526,15 @@ impl TakeQuery {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -601,6 +621,15 @@ impl FTSQuery {
|
||||
self.inner = self.inner.clone().postfilter();
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -771,6 +800,15 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().bypass_vector_index()
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
connection::Connection,
|
||||
error::PythonErrorExt,
|
||||
index::{extract_index_params, IndexConfig},
|
||||
query::{Query, TakeQuery},
|
||||
@@ -249,7 +250,7 @@ impl Table {
|
||||
}
|
||||
|
||||
impl Table {
|
||||
fn inner_ref(&self) -> PyResult<&LanceDbTable> {
|
||||
pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
|
||||
self.inner
|
||||
.as_ref()
|
||||
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
|
||||
@@ -272,6 +273,13 @@ impl Table {
|
||||
self.inner.take();
|
||||
}
|
||||
|
||||
pub fn database(&self) -> PyResult<Connection> {
|
||||
let inner = self.inner_ref()?.clone();
|
||||
let inner_connection =
|
||||
lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
|
||||
Ok(Connection::new(inner_connection))
|
||||
}
|
||||
|
||||
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "1.86.0"
|
||||
channel = "1.90.0"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.22.2-beta.1"
|
||||
version = "0.22.3-beta.0"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -11,10 +11,12 @@ rust-version.workspace = true
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
ahash = { workspace = true }
|
||||
arrow = { workspace = true }
|
||||
arrow-array = { workspace = true }
|
||||
arrow-data = { workspace = true }
|
||||
arrow-schema = { workspace = true }
|
||||
arrow-select = { workspace = true }
|
||||
arrow-ord = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
arrow-ipc.workspace = true
|
||||
@@ -24,12 +26,16 @@ datafusion-common.workspace = true
|
||||
datafusion-execution.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
datafusion-physical-plan.workspace = true
|
||||
datafusion.workspace = true
|
||||
object_store = { workspace = true }
|
||||
snafu = { workspace = true }
|
||||
half = { workspace = true }
|
||||
lazy_static.workspace = true
|
||||
lance = { workspace = true }
|
||||
lance-core = { workspace = true }
|
||||
lance-datafusion.workspace = true
|
||||
lance-datagen = { workspace = true }
|
||||
lance-file = { workspace = true }
|
||||
lance-io = { workspace = true }
|
||||
lance-index = { workspace = true }
|
||||
lance-table = { workspace = true }
|
||||
@@ -37,6 +43,7 @@ lance-linalg = { workspace = true }
|
||||
lance-testing = { workspace = true }
|
||||
lance-encoding = { workspace = true }
|
||||
lance-namespace = { workspace = true }
|
||||
lance-namespace-impls = { workspace = true, features = ["dir", "rest"] }
|
||||
moka = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
@@ -46,11 +53,13 @@ bytes = "1"
|
||||
futures.workspace = true
|
||||
num-traits.workspace = true
|
||||
url.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
serde = { version = "^1" }
|
||||
serde_json = { version = "1" }
|
||||
async-openai = { version = "0.20.0", optional = true }
|
||||
serde_with = { version = "3.8.1" }
|
||||
tempfile = "3.5.0"
|
||||
aws-sdk-bedrockruntime = { version = "1.27.0", optional = true }
|
||||
# For remote feature
|
||||
reqwest = { version = "0.12.0", default-features = false, features = [
|
||||
@@ -61,9 +70,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [
|
||||
"macos-system-configuration",
|
||||
"stream",
|
||||
], optional = true }
|
||||
rand = { version = "0.9", features = ["small_rng"], optional = true }
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
uuid = { version = "1.7.0", features = ["v4"], optional = true }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [
|
||||
@@ -84,7 +92,6 @@ bytemuck_derive.workspace = true
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
tempfile = "3.5.0"
|
||||
rand = { version = "0.9", features = ["small_rng"] }
|
||||
random_word = { version = "0.4.3", features = ["en"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
walkdir = "2"
|
||||
@@ -96,6 +103,7 @@ aws-smithy-runtime = { version = "1.9.1" }
|
||||
datafusion.workspace = true
|
||||
http-body = "1" # Matching reqwest
|
||||
rstest = "0.23.0"
|
||||
test-log = "0.2"
|
||||
|
||||
|
||||
[features]
|
||||
@@ -105,7 +113,7 @@ oss = ["lance/oss", "lance-io/oss"]
|
||||
gcs = ["lance/gcp", "lance-io/gcp"]
|
||||
azure = ["lance/azure", "lance-io/azure"]
|
||||
dynamodb = ["lance/dynamodb", "aws"]
|
||||
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
|
||||
remote = ["dep:reqwest", "dep:http"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
||||
|
||||
@@ -7,6 +7,7 @@ pub use arrow_schema;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use futures::{Stream, StreamExt, TryStreamExt};
|
||||
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
|
||||
|
||||
#[cfg(feature = "polars")]
|
||||
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
|
||||
@@ -161,6 +162,26 @@ impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanceDbDatagenExt {
|
||||
fn into_ldb_stream(
|
||||
self,
|
||||
batch_size: RowCount,
|
||||
num_batches: BatchCount,
|
||||
) -> SendableRecordBatchStream;
|
||||
}
|
||||
|
||||
impl LanceDbDatagenExt for BatchGeneratorBuilder {
|
||||
fn into_ldb_stream(
|
||||
self,
|
||||
batch_size: RowCount,
|
||||
num_batches: BatchCount,
|
||||
) -> SendableRecordBatchStream {
|
||||
let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
|
||||
let stream = stream.map_err(|err| Error::Arrow { source: err });
|
||||
Box::pin(SimpleRecordBatchStream::new(stream, schema))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "polars")]
|
||||
/// An iterator of record batches formed from a Polars DataFrame.
|
||||
pub struct PolarsDataFrameRecordBatchReader {
|
||||
|
||||
@@ -19,7 +19,7 @@ use crate::database::listing::{
|
||||
use crate::database::{
|
||||
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
|
||||
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
|
||||
OpenTableRequest, TableNamesRequest,
|
||||
OpenTableRequest, ReadConsistency, TableNamesRequest,
|
||||
};
|
||||
use crate::embeddings::{
|
||||
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
|
||||
@@ -152,6 +152,7 @@ impl CreateTableBuilder<true> {
|
||||
let request = self.into_request()?;
|
||||
Ok(Table::new_with_embedding_registry(
|
||||
parent.create_table(request).await?,
|
||||
parent,
|
||||
embedding_registry,
|
||||
))
|
||||
}
|
||||
@@ -211,9 +212,9 @@ impl CreateTableBuilder<false> {
|
||||
|
||||
/// Execute the create table operation
|
||||
pub async fn execute(self) -> Result<Table> {
|
||||
Ok(Table::new(
|
||||
self.parent.clone().create_table(self.request).await?,
|
||||
))
|
||||
let parent = self.parent.clone();
|
||||
let table = parent.create_table(self.request).await?;
|
||||
Ok(Table::new(table, parent))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -462,8 +463,10 @@ impl OpenTableBuilder {
|
||||
|
||||
/// Open the table
|
||||
pub async fn execute(self) -> Result<Table> {
|
||||
let table = self.parent.open_table(self.request).await?;
|
||||
Ok(Table::new_with_embedding_registry(
|
||||
self.parent.clone().open_table(self.request).await?,
|
||||
table,
|
||||
self.parent,
|
||||
self.embedding_registry,
|
||||
))
|
||||
}
|
||||
@@ -519,16 +522,15 @@ impl CloneTableBuilder {
|
||||
|
||||
/// Execute the clone operation
|
||||
pub async fn execute(self) -> Result<Table> {
|
||||
Ok(Table::new(
|
||||
self.parent.clone().clone_table(self.request).await?,
|
||||
))
|
||||
let parent = self.parent.clone();
|
||||
let table = parent.clone_table(self.request).await?;
|
||||
Ok(Table::new(table, parent))
|
||||
}
|
||||
}
|
||||
|
||||
/// A connection to LanceDB
|
||||
#[derive(Clone)]
|
||||
pub struct Connection {
|
||||
uri: String,
|
||||
internal: Arc<dyn Database>,
|
||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||
}
|
||||
@@ -540,9 +542,19 @@ impl std::fmt::Display for Connection {
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new(
|
||||
internal: Arc<dyn Database>,
|
||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
internal,
|
||||
embedding_registry,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the URI of the connection
|
||||
pub fn uri(&self) -> &str {
|
||||
self.uri.as_str()
|
||||
self.internal.uri()
|
||||
}
|
||||
|
||||
/// Get access to the underlying database
|
||||
@@ -675,6 +687,11 @@ impl Connection {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Get the read consistency of the connection
|
||||
pub async fn read_consistency(&self) -> Result<ReadConsistency> {
|
||||
self.internal.read_consistency().await
|
||||
}
|
||||
|
||||
/// Drop a table in the database.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -973,7 +990,6 @@ impl ConnectBuilder {
|
||||
)?);
|
||||
Ok(Connection {
|
||||
internal,
|
||||
uri: self.request.uri,
|
||||
embedding_registry: self
|
||||
.embedding_registry
|
||||
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
|
||||
@@ -996,7 +1012,6 @@ impl ConnectBuilder {
|
||||
let internal = Arc::new(ListingDatabase::connect_with_options(&self.request).await?);
|
||||
Ok(Connection {
|
||||
internal,
|
||||
uri: self.request.uri,
|
||||
embedding_registry: self
|
||||
.embedding_registry
|
||||
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
|
||||
@@ -1104,7 +1119,6 @@ impl ConnectNamespaceBuilder {
|
||||
|
||||
Ok(Connection {
|
||||
internal,
|
||||
uri: format!("namespace://{}", self.ns_impl),
|
||||
embedding_registry: self
|
||||
.embedding_registry
|
||||
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
|
||||
@@ -1139,7 +1153,6 @@ mod test_utils {
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||
Self {
|
||||
internal,
|
||||
uri: "db://test".to_string(),
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
@@ -1156,7 +1169,6 @@ mod test_utils {
|
||||
));
|
||||
Self {
|
||||
internal,
|
||||
uri: "db://test".to_string(),
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
@@ -1170,7 +1182,7 @@ mod tests {
|
||||
use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
|
||||
use crate::query::QueryBase;
|
||||
use crate::query::{ExecutableQuery, QueryExecutionOptions};
|
||||
use crate::test_connection::test_utils::new_test_connection;
|
||||
use crate::test_utils::connection::new_test_connection;
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
@@ -1187,7 +1199,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_connect() {
|
||||
let tc = new_test_connection().await.unwrap();
|
||||
assert_eq!(tc.connection.uri, tc.uri);
|
||||
assert_eq!(tc.connection.uri(), tc.uri);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
@@ -1208,7 +1220,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
|
||||
assert_eq!(db.uri(), relative_uri.to_str().unwrap().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use arrow_array::RecordBatchReader;
|
||||
use async_trait::async_trait;
|
||||
@@ -213,6 +214,20 @@ impl CloneTableRequest {
|
||||
}
|
||||
}
|
||||
|
||||
/// How long until a change is reflected from one Table instance to another
|
||||
///
|
||||
/// Tables are always internally consistent. If a write method is called on
|
||||
/// a table instance it will be immediately visible in that same table instance.
|
||||
pub enum ReadConsistency {
|
||||
/// Changes will not be automatically propagated until the checkout_latest
|
||||
/// method is called on the target table
|
||||
Manual,
|
||||
/// Changes will be propagated automatically within the given duration
|
||||
Eventual(Duration),
|
||||
/// Changes are immediately visible in target tables
|
||||
Strong,
|
||||
}
|
||||
|
||||
/// The `Database` trait defines the interface for database implementations.
|
||||
///
|
||||
/// A database is responsible for managing tables and their metadata.
|
||||
@@ -220,6 +235,10 @@ impl CloneTableRequest {
|
||||
pub trait Database:
|
||||
Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
|
||||
{
|
||||
/// Get the uri of the database
|
||||
fn uri(&self) -> &str;
|
||||
/// Get the read consistency of the database
|
||||
async fn read_consistency(&self) -> Result<ReadConsistency>;
|
||||
/// List immediate child namespace names in the given namespace
|
||||
async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>>;
|
||||
/// Create a new namespace
|
||||
|
||||
@@ -17,6 +17,7 @@ use object_store::local::LocalFileSystem;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::connection::ConnectRequest;
|
||||
use crate::database::ReadConsistency;
|
||||
use crate::error::{CreateDirSnafu, Error, Result};
|
||||
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||
use crate::table::NativeTable;
|
||||
@@ -598,6 +599,22 @@ impl Database for ListingDatabase {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
fn uri(&self) -> &str {
|
||||
&self.uri
|
||||
}
|
||||
|
||||
async fn read_consistency(&self) -> Result<ReadConsistency> {
|
||||
if let Some(read_consistency_inverval) = self.read_consistency_interval {
|
||||
if read_consistency_inverval.is_zero() {
|
||||
Ok(ReadConsistency::Strong)
|
||||
} else {
|
||||
Ok(ReadConsistency::Eventual(read_consistency_inverval))
|
||||
}
|
||||
} else {
|
||||
Ok(ReadConsistency::Manual)
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "Namespace operations are not supported for listing database".into(),
|
||||
@@ -1249,7 +1266,8 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let source_table_obj = Table::new(source_table.clone());
|
||||
let db = Arc::new(db);
|
||||
let source_table_obj = Table::new(source_table.clone(), db.clone());
|
||||
source_table_obj
|
||||
.add(Box::new(arrow_array::RecordBatchIterator::new(
|
||||
vec![Ok(batch2)],
|
||||
@@ -1320,7 +1338,8 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Create a tag for the current version
|
||||
let source_table_obj = Table::new(source_table.clone());
|
||||
let db = Arc::new(db);
|
||||
let source_table_obj = Table::new(source_table.clone(), db.clone());
|
||||
let mut tags = source_table_obj.tags().await.unwrap();
|
||||
tags.create("v1.0", source_table.version().await.unwrap())
|
||||
.await
|
||||
@@ -1336,7 +1355,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let source_table_obj = Table::new(source_table.clone());
|
||||
let source_table_obj = Table::new(source_table.clone(), db.clone());
|
||||
source_table_obj
|
||||
.add(Box::new(arrow_array::RecordBatchIterator::new(
|
||||
vec![Ok(batch2)],
|
||||
@@ -1432,7 +1451,8 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let cloned_table_obj = Table::new(cloned_table.clone());
|
||||
let db = Arc::new(db);
|
||||
let cloned_table_obj = Table::new(cloned_table.clone(), db.clone());
|
||||
cloned_table_obj
|
||||
.add(Box::new(arrow_array::RecordBatchIterator::new(
|
||||
vec![Ok(batch_clone)],
|
||||
@@ -1452,7 +1472,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let source_table_obj = Table::new(source_table.clone());
|
||||
let source_table_obj = Table::new(source_table.clone(), db);
|
||||
source_table_obj
|
||||
.add(Box::new(arrow_array::RecordBatchIterator::new(
|
||||
vec![Ok(batch_source)],
|
||||
@@ -1495,6 +1515,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// Add more data to create new versions
|
||||
let db = Arc::new(db);
|
||||
for i in 0..3 {
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
@@ -1502,7 +1523,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let source_table_obj = Table::new(source_table.clone());
|
||||
let source_table_obj = Table::new(source_table.clone(), db.clone());
|
||||
source_table_obj
|
||||
.add(Box::new(arrow_array::RecordBatchIterator::new(
|
||||
vec![Ok(batch)],
|
||||
|
||||
@@ -8,17 +8,17 @@ use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lance_namespace::{
|
||||
connect as connect_namespace,
|
||||
models::{
|
||||
CreateEmptyTableRequest, CreateNamespaceRequest, DescribeTableRequest,
|
||||
DropNamespaceRequest, DropTableRequest, ListNamespacesRequest, ListTablesRequest,
|
||||
},
|
||||
LanceNamespace,
|
||||
};
|
||||
use lance_namespace_impls::connect::connect as connect_namespace;
|
||||
|
||||
use crate::connection::ConnectRequest;
|
||||
use crate::database::listing::ListingDatabase;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::{connection::ConnectRequest, database::ReadConsistency};
|
||||
|
||||
use super::{
|
||||
BaseTable, CloneTableRequest, CreateNamespaceRequest as DbCreateNamespaceRequest,
|
||||
@@ -36,6 +36,8 @@ pub struct LanceNamespaceDatabase {
|
||||
read_consistency_interval: Option<std::time::Duration>,
|
||||
// Optional session for object stores and caching
|
||||
session: Option<Arc<lance::session::Session>>,
|
||||
// database URI
|
||||
uri: String,
|
||||
}
|
||||
|
||||
impl LanceNamespaceDatabase {
|
||||
@@ -57,6 +59,7 @@ impl LanceNamespaceDatabase {
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
uri: format!("namespace://{}", ns_impl),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -130,6 +133,22 @@ impl std::fmt::Display for LanceNamespaceDatabase {
|
||||
|
||||
#[async_trait]
|
||||
impl Database for LanceNamespaceDatabase {
|
||||
fn uri(&self) -> &str {
|
||||
&self.uri
|
||||
}
|
||||
|
||||
async fn read_consistency(&self) -> Result<ReadConsistency> {
|
||||
if let Some(read_consistency_inverval) = self.read_consistency_interval {
|
||||
if read_consistency_inverval.is_zero() {
|
||||
Ok(ReadConsistency::Strong)
|
||||
} else {
|
||||
Ok(ReadConsistency::Eventual(read_consistency_inverval))
|
||||
}
|
||||
} else {
|
||||
Ok(ReadConsistency::Manual)
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_namespaces(&self, request: DbListNamespacesRequest) -> Result<Vec<String>> {
|
||||
let ns_request = ListNamespacesRequest {
|
||||
id: if request.namespace.is_empty() {
|
||||
@@ -261,7 +280,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
return listing_db
|
||||
.open_table(OpenTableRequest {
|
||||
name: request.name.clone(),
|
||||
namespace: request.namespace.clone(),
|
||||
namespace: vec![],
|
||||
index_cache_size: None,
|
||||
lance_read_params: None,
|
||||
})
|
||||
@@ -305,7 +324,14 @@ impl Database for LanceNamespaceDatabase {
|
||||
)
|
||||
.await?;
|
||||
|
||||
listing_db.create_table(request).await
|
||||
let create_request = DbCreateTableRequest {
|
||||
name: request.name,
|
||||
namespace: vec![],
|
||||
data: request.data,
|
||||
mode: request.mode,
|
||||
write_options: request.write_options,
|
||||
};
|
||||
listing_db.create_table(create_request).await
|
||||
}
|
||||
|
||||
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
@@ -332,7 +358,13 @@ impl Database for LanceNamespaceDatabase {
|
||||
.create_listing_database(&request.name, &location, response.storage_options)
|
||||
.await?;
|
||||
|
||||
listing_db.open_table(request).await
|
||||
let open_request = OpenTableRequest {
|
||||
name: request.name.clone(),
|
||||
namespace: vec![],
|
||||
index_cache_size: request.index_cache_size,
|
||||
lance_read_params: request.lance_read_params,
|
||||
};
|
||||
listing_db.open_table(open_request).await
|
||||
}
|
||||
|
||||
async fn clone_table(&self, _request: CloneTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
|
||||
4
rust/lancedb/src/dataloader.rs
Normal file
4
rust/lancedb/src/dataloader.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
pub mod permutation;
|
||||
18
rust/lancedb/src/dataloader/permutation.rs
Normal file
18
rust/lancedb/src/dataloader/permutation.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Contains the [PermutationBuilder] to create a permutation "view" of an existing table.
|
||||
//!
|
||||
//! A permutation view can apply a filter, divide the data into splits, and shuffle the data.
|
||||
//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
|
||||
//! the underlying data and can be very lightweight.
|
||||
//!
|
||||
//! Building a permutation table should be fairly quick (it is an O(N) operation where N is
|
||||
//! the number of rows in the base table) and memory efficient, even for billions or trillions
|
||||
//! of rows.
|
||||
|
||||
pub mod builder;
|
||||
pub mod reader;
|
||||
pub mod shuffle;
|
||||
pub mod split;
|
||||
pub mod util;
|
||||
326
rust/lancedb/src/dataloader/permutation/builder.rs
Normal file
326
rust/lancedb/src/dataloader/permutation/builder.rs
Normal file
@@ -0,0 +1,326 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion::prelude::{SessionConfig, SessionContext};
|
||||
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
|
||||
use datafusion_expr::col;
|
||||
use futures::TryStreamExt;
|
||||
use lance_core::ROW_ID;
|
||||
use lance_datafusion::exec::SessionContextExt;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
|
||||
connect,
|
||||
database::{CreateTableData, CreateTableRequest, Database},
|
||||
dataloader::permutation::{
|
||||
shuffle::{Shuffler, ShufflerConfig},
|
||||
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
|
||||
util::{rename_column, TemporaryDirectory},
|
||||
},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
Error, Result, Table,
|
||||
};
|
||||
|
||||
pub const SRC_ROW_ID_COL: &str = "row_id";
|
||||
|
||||
/// Where to store the permutation table
|
||||
#[derive(Debug, Clone, Default)]
|
||||
enum PermutationDestination {
|
||||
/// The permutation table is a temporary table in memory
|
||||
#[default]
|
||||
Temporary,
|
||||
/// The permutation table is a permanent table in a database
|
||||
Permanent(Arc<dyn Database>, String),
|
||||
}
|
||||
|
||||
/// Configuration for creating a permutation table
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PermutationConfig {
|
||||
/// Splitting configuration
|
||||
split_strategy: SplitStrategy,
|
||||
/// Shuffle strategy
|
||||
shuffle_strategy: ShuffleStrategy,
|
||||
/// Optional filter to apply to the base table
|
||||
filter: Option<String>,
|
||||
/// Directory to use for temporary files
|
||||
temp_dir: TemporaryDirectory,
|
||||
/// Destination
|
||||
destination: PermutationDestination,
|
||||
}
|
||||
|
||||
/// Strategy for shuffling the data.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ShuffleStrategy {
|
||||
/// The data is randomly shuffled
|
||||
///
|
||||
/// A seed can be provided to make the shuffle deterministic.
|
||||
///
|
||||
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
|
||||
/// This decreases the overall randomization but can improve I/O performance when reading from
|
||||
/// cloud storage.
|
||||
///
|
||||
/// For example, a clump size of 16 will means we will shuffle blocks of 16 contiguous rows. This
|
||||
/// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence
|
||||
/// the performance of the model. Note: shuffling within clumps can still be done at read time but
|
||||
/// this will only provide a local shuffle and not a global shuffle.
|
||||
Random {
|
||||
seed: Option<u64>,
|
||||
clump_size: Option<u64>,
|
||||
},
|
||||
/// The data is not shuffled
|
||||
///
|
||||
/// This is useful for debugging and testing.
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for ShuffleStrategy {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating a permutation table.
|
||||
///
|
||||
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
|
||||
/// can be used to create a permutation reader that reads rows in the order defined by the permutation.
|
||||
///
|
||||
/// The permutation table is not a materialized copy of the underlying data and can be very lightweight.
|
||||
/// It is not a view of the underlying data and is not a copy of the data. It is a separate table that
|
||||
/// stores just row id and split id.
|
||||
pub struct PermutationBuilder {
|
||||
config: PermutationConfig,
|
||||
base_table: Table,
|
||||
}
|
||||
|
||||
impl PermutationBuilder {
|
||||
pub fn new(base_table: Table) -> Self {
|
||||
Self {
|
||||
config: PermutationConfig::default(),
|
||||
base_table,
|
||||
}
|
||||
}
|
||||
|
||||
/// Configures the strategy for assigning rows to splits.
|
||||
///
|
||||
/// For example, it is common to create a test/train split of the data. Splits can also be used
|
||||
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
|
||||
/// create a single split with 10% of the data.
|
||||
///
|
||||
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
|
||||
/// multiple processes and multiple nodes.
|
||||
///
|
||||
/// The default is a single split that contains all rows.
|
||||
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
|
||||
self.config.split_strategy = split_strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the strategy for shuffling the data.
|
||||
///
|
||||
/// The default is to shuffle the data randomly at row-level granularity (no clump size) and
|
||||
/// with a random seed.
|
||||
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
|
||||
self.config.shuffle_strategy = shuffle_strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures a filter to apply to the base table.
|
||||
///
|
||||
/// Only rows matching the filter will be included in the permutation.
|
||||
pub fn with_filter(mut self, filter: String) -> Self {
|
||||
self.config.filter = Some(filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the directory to use for temporary files.
|
||||
///
|
||||
/// The default is to use the operating system's default temporary directory.
|
||||
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
|
||||
self.config.temp_dir = temp_dir;
|
||||
self
|
||||
}
|
||||
|
||||
/// Stores the permutation as a table in a database
|
||||
///
|
||||
/// By default, the permutation is stored in memory. If this method is called then
|
||||
/// the permutation will be stored as a table in the given database.
|
||||
pub fn persist(mut self, database: Arc<dyn Database>, table_name: String) -> Self {
|
||||
self.config.destination = PermutationDestination::Permanent(database, table_name);
|
||||
self
|
||||
}
|
||||
|
||||
async fn sort_by_split_id(
|
||||
&self,
|
||||
data: SendableRecordBatchStream,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let ctx = SessionContext::new_with_config_rt(
|
||||
SessionConfig::default(),
|
||||
RuntimeEnvBuilder::new()
|
||||
.with_memory_limit(100 * 1024 * 1024, 1.0)
|
||||
.with_disk_manager_builder(
|
||||
DiskManagerBuilder::default()
|
||||
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
|
||||
)
|
||||
.build_arc()
|
||||
.unwrap(),
|
||||
);
|
||||
let df = ctx
|
||||
.read_one_shot(data.into_df_stream())
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to setup sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
let df_stream = df
|
||||
.sort_by(vec![col(SPLIT_ID_COLUMN)])
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to plan sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?
|
||||
.execute_stream()
|
||||
.await
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
|
||||
let schema = df_stream.schema();
|
||||
let stream = df_stream.map_err(|e| Error::Other {
|
||||
message: format!("Failed to execute sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
});
|
||||
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
|
||||
}
|
||||
|
||||
/// Builds the permutation table and stores it in the given database.
|
||||
pub async fn build(self) -> Result<Table> {
|
||||
// First pass, apply filter and load row ids
|
||||
let mut rows = self.base_table.query().with_row_id();
|
||||
|
||||
if let Some(filter) = &self.config.filter {
|
||||
rows = rows.only_if(filter);
|
||||
}
|
||||
|
||||
let splitter = Splitter::new(
|
||||
self.config.temp_dir.clone(),
|
||||
self.config.split_strategy.clone(),
|
||||
);
|
||||
|
||||
let mut needs_sort = !splitter.orders_by_split_id();
|
||||
|
||||
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
|
||||
// split id)
|
||||
rows = splitter.project(rows);
|
||||
|
||||
let num_rows = self
|
||||
.base_table
|
||||
.count_rows(self.config.filter.clone())
|
||||
.await? as u64;
|
||||
|
||||
// Apply splits
|
||||
let rows = rows.execute().await?;
|
||||
let split_data = splitter.apply(rows, num_rows).await?;
|
||||
|
||||
// Shuffle data if requested
|
||||
let shuffled = match self.config.shuffle_strategy {
|
||||
ShuffleStrategy::None => split_data,
|
||||
ShuffleStrategy::Random { seed, clump_size } => {
|
||||
let shuffler = Shuffler::new(ShufflerConfig {
|
||||
seed,
|
||||
clump_size,
|
||||
temp_dir: self.config.temp_dir.clone(),
|
||||
max_rows_per_file: 10 * 1024 * 1024,
|
||||
});
|
||||
shuffler.shuffle(split_data, num_rows).await?
|
||||
}
|
||||
};
|
||||
|
||||
// We want the final permutation to be sorted by the split id. If we shuffled or if
|
||||
// the split was not assigned sequentially then we need to sort the data.
|
||||
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
|
||||
|
||||
let sorted = if needs_sort {
|
||||
self.sort_by_split_id(shuffled).await?
|
||||
} else {
|
||||
shuffled
|
||||
};
|
||||
|
||||
// Rename _rowid to row_id
|
||||
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
|
||||
|
||||
let (name, database) = match &self.config.destination {
|
||||
PermutationDestination::Permanent(database, table_name) => {
|
||||
(table_name.as_str(), database.clone())
|
||||
}
|
||||
PermutationDestination::Temporary => {
|
||||
let conn = connect("memory:///").execute().await?;
|
||||
("permutation", conn.database().clone())
|
||||
}
|
||||
};
|
||||
|
||||
let create_table_request =
|
||||
CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
|
||||
|
||||
let table = database.create_table(create_table_request).await?;
|
||||
Ok(Table::new(table, database))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow::datatypes::Int32Type;
|
||||
use lance_datagen::{BatchCount, RowCount};
|
||||
|
||||
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::permutation::split::SplitSizes};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_permutation_builder() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
|
||||
let db = connect(temp_dir.path().to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let initial_data = lance_datagen::gen_batch()
|
||||
.col("some_value", lance_datagen::array::step::<Int32Type>())
|
||||
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
|
||||
let data_table = db
|
||||
.create_table_streaming("mytbl", initial_data)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let permutation_table = PermutationBuilder::new(data_table.clone())
|
||||
.with_filter("some_value > 57".to_string())
|
||||
.with_split_strategy(SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
|
||||
})
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("permutation_table: {:?}", permutation_table);
|
||||
|
||||
// Potentially brittle seed-dependent values below
|
||||
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
|
||||
assert_eq!(
|
||||
permutation_table
|
||||
.count_rows(Some("split_id = 0".to_string()))
|
||||
.await
|
||||
.unwrap(),
|
||||
47
|
||||
);
|
||||
assert_eq!(
|
||||
permutation_table
|
||||
.count_rows(Some("split_id = 1".to_string()))
|
||||
.await
|
||||
.unwrap(),
|
||||
283
|
||||
);
|
||||
}
|
||||
}
|
||||
384
rust/lancedb/src/dataloader/permutation/reader.rs
Normal file
384
rust/lancedb/src/dataloader/permutation/reader.rs
Normal file
@@ -0,0 +1,384 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Row ID-based views for LanceDB tables
|
||||
//!
|
||||
//! This module provides functionality for creating views that are based on specific row IDs.
|
||||
//! The `IdView` allows you to create a virtual table that contains only
|
||||
//! the rows from a source table that correspond to row IDs stored in a separate table.
|
||||
|
||||
use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
|
||||
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
|
||||
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
|
||||
use crate::error::Error;
|
||||
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
|
||||
use crate::table::{AnyQuery, BaseTable};
|
||||
use crate::Result;
|
||||
use arrow::array::AsArray;
|
||||
use arrow::datatypes::UInt64Type;
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::arrow::RecordBatchExt;
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::error::LanceOptionExt;
|
||||
use lance_core::ROW_ID;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Reads a permutation of a source table based on row IDs stored in a separate table
|
||||
pub struct PermutationReader {
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Arc<dyn BaseTable>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PermutationReader {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"PermutationReader(base={}, permutation={})",
|
||||
self.base_table.name(),
|
||||
self.permutation_table.name(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl PermutationReader {
|
||||
/// Create a new PermutationReader
|
||||
pub async fn try_new(
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Arc<dyn BaseTable>,
|
||||
) -> Result<Self> {
|
||||
let schema = permutation_table.schema().await?;
|
||||
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named row_id".to_string(),
|
||||
});
|
||||
}
|
||||
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named split_id".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(Self {
|
||||
base_table,
|
||||
permutation_table,
|
||||
})
|
||||
}
|
||||
|
||||
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
|
||||
for (expected, idx) in iter.enumerate() {
|
||||
if *idx != expected as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
async fn load_batch(
|
||||
base_table: &Arc<dyn BaseTable>,
|
||||
row_ids: RecordBatch,
|
||||
selection: Select,
|
||||
has_row_id: bool,
|
||||
) -> Result<RecordBatch> {
|
||||
let num_rows = row_ids.num_rows();
|
||||
let row_ids = row_ids
|
||||
.column(0)
|
||||
.as_primitive_opt::<UInt64Type>()
|
||||
.expect_ok()?
|
||||
.values();
|
||||
|
||||
let filter = format!(
|
||||
"_rowid in ({})",
|
||||
row_ids
|
||||
.iter()
|
||||
.map(|o| o.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let base_query = QueryRequest {
|
||||
filter: Some(QueryFilter::Sql(filter)),
|
||||
select: selection,
|
||||
with_row_id: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut data = base_table
|
||||
.query(
|
||||
&AnyQuery::Query(base_query),
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: num_rows as u32,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let Some(batch) = data.try_next().await? else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned no batches".to_string(),
|
||||
});
|
||||
};
|
||||
if data.try_next().await?.is_some() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned more than one batch".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if batch.num_rows() != num_rows {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned different number of rows than the number of row IDs"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// There is no guarantee the result order will match the order provided
|
||||
// so may need to restore order
|
||||
let actual_row_ids = batch
|
||||
.column_by_name(ROW_ID)
|
||||
.expect_ok()?
|
||||
.as_primitive_opt::<UInt64Type>()
|
||||
.expect_ok()?
|
||||
.values();
|
||||
|
||||
// Map from row id to order in batch, used to restore original ordering
|
||||
let ordering = actual_row_ids
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.map(|(i, o)| (o, i as u64))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let desired_idx_order = row_ids
|
||||
.iter()
|
||||
.map(|o| ordering.get(o).copied().expect_ok().map_err(Error::from))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let ordered_batch = if Self::is_sorted_already(desired_idx_order.iter()) {
|
||||
// Fast path if already sorted, important as data may be large and
|
||||
// re-ordering could be expensive
|
||||
batch
|
||||
} else {
|
||||
let desired_idx_order = UInt64Array::from(desired_idx_order);
|
||||
|
||||
arrow_select::take::take_record_batch(&batch, &desired_idx_order)?
|
||||
};
|
||||
|
||||
if has_row_id {
|
||||
Ok(ordered_batch)
|
||||
} else {
|
||||
// The user didn't ask for row id, we needed it for ordering the data, but now we drop it
|
||||
Ok(ordered_batch.drop_column(ROW_ID)?)
|
||||
}
|
||||
}
|
||||
|
||||
async fn row_ids_to_batches(
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
row_ids: DatasetRecordBatchStream,
|
||||
selection: Select,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let has_row_id = Self::has_row_id(&selection)?;
|
||||
let mut stream = row_ids
|
||||
.map_err(Error::from)
|
||||
.try_filter_map(move |batch| {
|
||||
let selection = selection.clone();
|
||||
let base_table = base_table.clone();
|
||||
async move {
|
||||
Self::load_batch(&base_table, batch, selection, has_row_id)
|
||||
.await
|
||||
.map(Some)
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
|
||||
// Need to read out first batch to get schema
|
||||
let Some(first_batch) = stream.try_next().await? else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation was empty".to_string(),
|
||||
});
|
||||
};
|
||||
let schema = first_batch.schema();
|
||||
|
||||
let stream = futures::stream::once(std::future::ready(Ok(first_batch))).chain(stream);
|
||||
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
|
||||
}
|
||||
|
||||
fn has_row_id(selection: &Select) -> Result<bool> {
|
||||
match selection {
|
||||
Select::All => {
|
||||
// _rowid is a system column and is not included in Select::All
|
||||
Ok(false)
|
||||
}
|
||||
Select::Columns(columns) => Ok(columns.contains(&ROW_ID.to_string())),
|
||||
Select::Dynamic(columns) => {
|
||||
for column in columns {
|
||||
if column.0 == ROW_ID {
|
||||
if column.1 == ROW_ID {
|
||||
return Ok(true);
|
||||
} else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"Dynamic column {} cannot be used to select _rowid",
|
||||
column.1
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_split(
|
||||
&self,
|
||||
split: u64,
|
||||
selection: Select,
|
||||
execution_options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let row_ids = self
|
||||
.permutation_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
|
||||
filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
|
||||
..Default::default()
|
||||
}),
|
||||
execution_options,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow::datatypes::Int32Type;
|
||||
use arrow_array::{ArrowPrimitiveType, RecordBatch, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance_datagen::{BatchCount, RowCount};
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
use crate::{
|
||||
arrow::SendableRecordBatchStream,
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
test_utils::datagen::{virtual_table, LanceDbDatagenExt},
|
||||
Table,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
async fn collect_from_stream<T: ArrowPrimitiveType>(
|
||||
mut stream: SendableRecordBatchStream,
|
||||
column: &str,
|
||||
) -> Vec<T::Native> {
|
||||
let mut row_ids = Vec::new();
|
||||
while let Some(batch) = stream.try_next().await.unwrap() {
|
||||
let col_idx = batch.schema().index_of(column).unwrap();
|
||||
row_ids.extend(batch.column(col_idx).as_primitive::<T>().values().to_vec());
|
||||
}
|
||||
row_ids
|
||||
}
|
||||
|
||||
async fn collect_column<T: ArrowPrimitiveType>(table: &Table, column: &str) -> Vec<T::Native> {
|
||||
collect_from_stream::<T>(
|
||||
table
|
||||
.query()
|
||||
.select(Select::Columns(vec![column.to_string()]))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap(),
|
||||
column,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_permutation_reader() {
|
||||
let base_table = lance_datagen::gen_batch()
|
||||
.col("idx", lance_datagen::array::step::<Int32Type>())
|
||||
.col("other_col", lance_datagen::array::step::<UInt64Type>())
|
||||
.into_mem_table("tbl", RowCount::from(9), BatchCount::from(1))
|
||||
.await;
|
||||
|
||||
let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
|
||||
row_ids.shuffle(&mut rand::rng());
|
||||
// Put the last two rows in split 1
|
||||
let split_ids = UInt64Array::from_iter_values(
|
||||
std::iter::repeat_n(0, row_ids.len() - 2).chain(std::iter::repeat_n(1, 2)),
|
||||
);
|
||||
let permutation_batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("row_id", DataType::UInt64, false),
|
||||
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(UInt64Array::from(row_ids.clone())),
|
||||
Arc::new(split_ids),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
|
||||
|
||||
let reader = PermutationReader::try_new(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Read split 0
|
||||
let mut stream = reader
|
||||
.read_split(
|
||||
0,
|
||||
Select::All,
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: 3,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stream.schema(), base_table.schema().await.unwrap());
|
||||
|
||||
let check_batch = async |stream: &mut SendableRecordBatchStream,
|
||||
expected_values: &[u64]| {
|
||||
let batch = stream.try_next().await.unwrap().unwrap();
|
||||
assert_eq!(batch.num_rows(), expected_values.len());
|
||||
assert_eq!(
|
||||
batch.column(0).as_primitive::<Int32Type>().values(),
|
||||
&expected_values
|
||||
.iter()
|
||||
.map(|o| *o as i32)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
assert_eq!(
|
||||
batch.column(1).as_primitive::<UInt64Type>().values(),
|
||||
&expected_values
|
||||
);
|
||||
};
|
||||
|
||||
check_batch(&mut stream, &row_ids[0..3]).await;
|
||||
check_batch(&mut stream, &row_ids[3..6]).await;
|
||||
check_batch(&mut stream, &row_ids[6..7]).await;
|
||||
assert!(stream.try_next().await.unwrap().is_none());
|
||||
|
||||
// Read split 1
|
||||
let mut stream = reader
|
||||
.read_split(
|
||||
1,
|
||||
Select::All,
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: 3,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
check_batch(&mut stream, &row_ids[7..9]).await;
|
||||
assert!(stream.try_next().await.unwrap().is_none());
|
||||
}
|
||||
}
|
||||
475
rust/lancedb/src/dataloader/permutation/shuffle.rs
Normal file
475
rust/lancedb/src/dataloader/permutation/shuffle.rs
Normal file
@@ -0,0 +1,475 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::io::ObjectStore;
|
||||
use lance_core::{cache::LanceCache, utils::futures::FinallyStreamExt};
|
||||
use lance_encoding::decoder::DecoderPlugins;
|
||||
use lance_file::v2::{
|
||||
reader::{FileReader, FileReaderOptions},
|
||||
writer::{FileWriter, FileWriterOptions},
|
||||
};
|
||||
use lance_index::scalar::IndexReader;
|
||||
use lance_io::{
|
||||
scheduler::{ScanScheduler, SchedulerConfig},
|
||||
utils::CachedFileSize,
|
||||
};
|
||||
use rand::{seq::SliceRandom, Rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ShufflerConfig {
|
||||
/// An optional seed to make the shuffle deterministic
|
||||
pub seed: Option<u64>,
|
||||
/// The maximum number of rows to write to a single file
|
||||
///
|
||||
/// The shuffler will need to hold at least this many rows in memory. Setting this value
|
||||
/// extremely large could cause the shuffler to use a lot of memory (depending on row size).
|
||||
///
|
||||
/// However, the shuffler will also need to hold total_num_rows / max_rows_per_file file
|
||||
/// writers in memory. Each of these will consume some amount of data for column write buffers.
|
||||
/// So setting this value too small could _also_ cause the shuffler to use a lot of memory and
|
||||
/// open file handles.
|
||||
pub max_rows_per_file: u64,
|
||||
/// The temporary directory to use for writing files
|
||||
pub temp_dir: TemporaryDirectory,
|
||||
/// The size of the clumps to shuffle within
|
||||
///
|
||||
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
|
||||
/// This decreases the overall randomization but can improve I/O performance when reading from
|
||||
/// cloud storage.
|
||||
pub clump_size: Option<u64>,
|
||||
}
|
||||
|
||||
impl Default for ShufflerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_rows_per_file: 1024 * 1024,
|
||||
seed: Option::default(),
|
||||
temp_dir: TemporaryDirectory::default(),
|
||||
clump_size: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A shuffler that can shuffle a stream of record batches
|
||||
///
|
||||
/// To do this the stream is consumed and written to temporary files. A new stream is returned
|
||||
/// which returns the shuffled data from the temporary files.
|
||||
///
|
||||
/// If there are fewer than max_rows_per_file rows in the input stream, then the shuffler will not
|
||||
/// write any files and will instead perform an in-memory shuffle.
|
||||
///
|
||||
/// The number of rows in the input stream must be known in advance.
|
||||
pub struct Shuffler {
|
||||
config: ShufflerConfig,
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl Shuffler {
|
||||
pub fn new(config: ShufflerConfig) -> Self {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
Self { config, id }
|
||||
}
|
||||
|
||||
/// Shuffles a single batch of data in memory
|
||||
fn shuffle_batch(
|
||||
batch: &RecordBatch,
|
||||
rng: &mut dyn RngCore,
|
||||
clump_size: u64,
|
||||
) -> Result<RecordBatch> {
|
||||
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
|
||||
let mut indices = (0..num_clumps).collect::<Vec<_>>();
|
||||
indices.shuffle(rng);
|
||||
let indices = if clump_size == 1 {
|
||||
UInt64Array::from(indices)
|
||||
} else {
|
||||
UInt64Array::from_iter_values(indices.iter().flat_map(|&clump_index| {
|
||||
if clump_index == num_clumps - 1 {
|
||||
clump_index * clump_size..batch.num_rows() as u64
|
||||
} else {
|
||||
clump_index * clump_size..(clump_index + 1) * clump_size
|
||||
}
|
||||
}))
|
||||
};
|
||||
Ok(arrow::compute::take_record_batch(batch, &indices)?)
|
||||
}
|
||||
|
||||
async fn in_memory_shuffle(
|
||||
&self,
|
||||
data: SendableRecordBatchStream,
|
||||
mut rng: Box<dyn RngCore + Send>,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let schema = data.schema();
|
||||
let batches = data.try_collect::<Vec<_>>().await?;
|
||||
let batch = concat_batches(&schema, &batches)?;
|
||||
let shuffled = Self::shuffle_batch(&batch, &mut rng, self.config.clump_size.unwrap_or(1))?;
|
||||
log::debug!("Shuffle job {}: in-memory shuffle complete", self.id);
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(
|
||||
futures::stream::once(async move { Ok(shuffled) }),
|
||||
schema,
|
||||
)))
|
||||
}
|
||||
|
||||
async fn do_shuffle(
|
||||
&self,
|
||||
mut data: SendableRecordBatchStream,
|
||||
num_rows: u64,
|
||||
mut rng: Box<dyn RngCore + Send>,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let num_files = num_rows.div_ceil(self.config.max_rows_per_file);
|
||||
|
||||
let temp_dir = self.config.temp_dir.create_temp_dir()?;
|
||||
let tmp_dir = temp_dir.path().to_path_buf();
|
||||
|
||||
let clump_size = self.config.clump_size.unwrap_or(1);
|
||||
if clump_size == 0 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "clump size must be greater than 0".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let object_store = ObjectStore::local();
|
||||
let arrow_schema = data.schema();
|
||||
let schema = lance::datatypes::Schema::try_from(arrow_schema.as_ref())?;
|
||||
|
||||
// Create file writers
|
||||
let mut file_writers = Vec::with_capacity(num_files as usize);
|
||||
for file_index in 0..num_files {
|
||||
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", self.id));
|
||||
let path =
|
||||
object_store::path::Path::from_absolute_path(path).map_err(|err| Error::Other {
|
||||
message: format!("Failed to create temporary file: {}", err),
|
||||
source: None,
|
||||
})?;
|
||||
let object_writer = object_store.create(&path).await?;
|
||||
let writer =
|
||||
FileWriter::try_new(object_writer, schema.clone(), FileWriterOptions::default())?;
|
||||
file_writers.push(writer);
|
||||
}
|
||||
|
||||
let mut num_rows_seen = 0;
|
||||
|
||||
// Randomly distribute clumps to files
|
||||
while let Some(batch) = data.try_next().await? {
|
||||
num_rows_seen += batch.num_rows() as u64;
|
||||
let is_last = num_rows_seen == num_rows;
|
||||
if num_rows_seen > num_rows {
|
||||
return Err(Error::Runtime {
|
||||
message: format!("Expected {} rows but saw {} rows", num_rows, num_rows_seen),
|
||||
});
|
||||
}
|
||||
// This is kind of an annoying limitation but if we allow runt clumps from batches then
|
||||
// clumps will get unaligned and we will mess up the clumps when we do the in-memory
|
||||
// shuffle step. If this is a problem we can probably figure out a better way to do this.
|
||||
if !is_last && batch.num_rows() as u64 % clump_size != 0 {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Expected batch size ({}) to be divisible by clump size ({})",
|
||||
batch.num_rows(),
|
||||
clump_size
|
||||
),
|
||||
});
|
||||
}
|
||||
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
|
||||
let mut batch_offsets_for_files =
|
||||
vec![Vec::<u64>::with_capacity(batch.num_rows()); num_files as usize];
|
||||
// Partition the batch randomly and write to the appropriate accumulator
|
||||
for clump_offset in 0..num_clumps {
|
||||
let clump_start = clump_offset * clump_size;
|
||||
let num_rows_in_clump = clump_size.min(batch.num_rows() as u64 - clump_start);
|
||||
let clump_end = clump_start + num_rows_in_clump;
|
||||
let file_index = rng.random_range(0..num_files);
|
||||
batch_offsets_for_files[file_index as usize].extend(clump_start..clump_end);
|
||||
}
|
||||
for (file_index, batch_offsets) in batch_offsets_for_files.into_iter().enumerate() {
|
||||
if batch_offsets.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let indices = UInt64Array::from(batch_offsets);
|
||||
let partition = arrow::compute::take_record_batch(&batch, &indices)?;
|
||||
file_writers[file_index].write_batch(&partition).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Finish writing files
|
||||
for (file_idx, mut writer) in file_writers.into_iter().enumerate() {
|
||||
let num_written = writer.finish().await?;
|
||||
log::debug!(
|
||||
"Shuffle job {}: wrote {} rows to file {}",
|
||||
self.id,
|
||||
num_written,
|
||||
file_idx
|
||||
);
|
||||
}
|
||||
|
||||
let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
|
||||
let scan_scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config);
|
||||
let job_id = self.id.clone();
|
||||
let rng = Arc::new(Mutex::new(rng));
|
||||
|
||||
// Second pass, read each file as a single batch and shuffle
|
||||
let stream = futures::stream::iter(0..num_files)
|
||||
.then(move |file_index| {
|
||||
let scan_scheduler = scan_scheduler.clone();
|
||||
let rng = rng.clone();
|
||||
let tmp_dir = tmp_dir.clone();
|
||||
let job_id = job_id.clone();
|
||||
async move {
|
||||
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", job_id));
|
||||
let path = object_store::path::Path::from_absolute_path(path).unwrap();
|
||||
let file_scheduler = scan_scheduler
|
||||
.open_file(&path, &CachedFileSize::unknown())
|
||||
.await?;
|
||||
let reader = FileReader::try_open(
|
||||
file_scheduler,
|
||||
None,
|
||||
Arc::<DecoderPlugins>::default(),
|
||||
&LanceCache::no_cache(),
|
||||
FileReaderOptions::default(),
|
||||
)
|
||||
.await?;
|
||||
// Need to read the entire file in a single batch for in-memory shuffling
|
||||
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
|
||||
let mut rng = rng.lock().unwrap();
|
||||
Self::shuffle_batch(&batch, &mut rng, clump_size)
|
||||
}
|
||||
})
|
||||
.finally(move || drop(temp_dir))
|
||||
.boxed();
|
||||
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(stream, arrow_schema)))
|
||||
}
|
||||
|
||||
pub async fn shuffle(
|
||||
self,
|
||||
data: SendableRecordBatchStream,
|
||||
num_rows: u64,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
log::debug!(
|
||||
"Shuffle job {}: shuffling {} rows and {} columns",
|
||||
self.id,
|
||||
num_rows,
|
||||
data.schema().fields.len()
|
||||
);
|
||||
let rng = non_crypto_rng(&self.config.seed);
|
||||
|
||||
if num_rows < self.config.max_rows_per_file {
|
||||
return self.in_memory_shuffle(data, rng).await;
|
||||
}
|
||||
|
||||
self.do_shuffle(data, num_rows, rng).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::arrow::LanceDbDatagenExt;
|
||||
|
||||
use super::*;
|
||||
use arrow::{array::AsArray, datatypes::Int32Type};
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_expr::col;
|
||||
use futures::TryStreamExt;
|
||||
use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed};
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
|
||||
fn test_gen() -> BatchGeneratorBuilder {
|
||||
lance_datagen::gen_batch()
|
||||
.with_seed(Seed::from(42))
|
||||
.col("id", lance_datagen::array::step::<Int32Type>())
|
||||
.col(
|
||||
"name",
|
||||
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_test_batch(size: RowCount) -> RecordBatch {
|
||||
test_gen().into_batch_rows(size).unwrap()
|
||||
}
|
||||
|
||||
fn create_test_stream(
|
||||
num_batches: BatchCount,
|
||||
batch_size: RowCount,
|
||||
) -> SendableRecordBatchStream {
|
||||
test_gen().into_ldb_stream(batch_size, num_batches)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shuffle_batch_deterministic() {
|
||||
let batch = create_test_batch(RowCount::from(10));
|
||||
let mut rng1 = SmallRng::seed_from_u64(42);
|
||||
let mut rng2 = SmallRng::seed_from_u64(42);
|
||||
|
||||
let shuffled1 = Shuffler::shuffle_batch(&batch, &mut rng1, 1).unwrap();
|
||||
let shuffled2 = Shuffler::shuffle_batch(&batch, &mut rng2, 1).unwrap();
|
||||
|
||||
// Same seed should produce same shuffle
|
||||
assert_eq!(shuffled1, shuffled2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shuffle_with_clumps() {
|
||||
let batch = create_test_batch(RowCount::from(10));
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 3).unwrap();
|
||||
let values = shuffled.column(0).as_primitive::<Int32Type>();
|
||||
|
||||
let mut iter = values.into_iter().map(|o| o.unwrap());
|
||||
let mut frag_seen = false;
|
||||
let mut clumps_seen = 0;
|
||||
while let Some(first) = iter.next() {
|
||||
// 9 is the last value and not a full clump
|
||||
if first != 9 {
|
||||
// Otherwise we should have a full clump
|
||||
let second = iter.next().unwrap();
|
||||
let third = iter.next().unwrap();
|
||||
assert_eq!(first + 1, second);
|
||||
assert_eq!(first + 2, third);
|
||||
clumps_seen += 1;
|
||||
} else {
|
||||
frag_seen = true;
|
||||
}
|
||||
}
|
||||
assert_eq!(clumps_seen, 3);
|
||||
assert!(frag_seen);
|
||||
}
|
||||
|
||||
async fn sort_batch(batch: RecordBatch) -> RecordBatch {
|
||||
let ctx = SessionContext::new();
|
||||
let df = ctx.read_batch(batch).unwrap();
|
||||
let sorted = df.sort_by(vec![col("id")]).unwrap();
|
||||
let batches = sorted.collect().await.unwrap();
|
||||
let schema = batches[0].schema();
|
||||
concat_batches(&schema, &batches).unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shuffle_batch_preserves_data() {
|
||||
let batch = create_test_batch(RowCount::from(100));
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
|
||||
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
|
||||
|
||||
assert_ne!(shuffled, batch);
|
||||
|
||||
let sorted = sort_batch(shuffled).await;
|
||||
|
||||
assert_eq!(sorted, batch);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shuffle_batch_empty() {
|
||||
let batch = create_test_batch(RowCount::from(0));
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
|
||||
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
|
||||
assert_eq!(shuffled.num_rows(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_in_memory_shuffle() {
|
||||
let config = ShufflerConfig {
|
||||
temp_dir: TemporaryDirectory::None,
|
||||
..Default::default()
|
||||
};
|
||||
let shuffler = Shuffler::new(config);
|
||||
|
||||
let stream = create_test_stream(BatchCount::from(5), RowCount::from(20));
|
||||
|
||||
let result_stream = shuffler.shuffle(stream, 100).await.unwrap();
|
||||
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
|
||||
|
||||
assert_eq!(result_batches.len(), 1);
|
||||
let result_batch = result_batches.into_iter().next().unwrap();
|
||||
|
||||
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(20))
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let schema = unshuffled_batches[0].schema();
|
||||
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
|
||||
|
||||
let sorted = sort_batch(result_batch).await;
|
||||
|
||||
assert_eq!(unshuffled_batch, sorted);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_external_shuffle() {
|
||||
let config = ShufflerConfig {
|
||||
max_rows_per_file: 100,
|
||||
..Default::default()
|
||||
};
|
||||
let shuffler = Shuffler::new(config);
|
||||
|
||||
let stream = create_test_stream(BatchCount::from(5), RowCount::from(1000));
|
||||
|
||||
let result_stream = shuffler.shuffle(stream, 5000).await.unwrap();
|
||||
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
|
||||
|
||||
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(1000))
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let schema = unshuffled_batches[0].schema();
|
||||
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
|
||||
|
||||
assert_eq!(result_batches.len(), 50);
|
||||
let result_batch = concat_batches(&schema, &result_batches).unwrap();
|
||||
|
||||
let sorted = sort_batch(result_batch).await;
|
||||
|
||||
assert_eq!(unshuffled_batch, sorted);
|
||||
}
|
||||
|
||||
#[test_log::test(tokio::test)]
|
||||
async fn test_external_clump_shuffle() {
|
||||
let config = ShufflerConfig {
|
||||
max_rows_per_file: 100,
|
||||
clump_size: Some(30),
|
||||
..Default::default()
|
||||
};
|
||||
let shuffler = Shuffler::new(config);
|
||||
|
||||
// Batch size (900) must be multiple of clump size (30)
|
||||
let stream = create_test_stream(BatchCount::from(5), RowCount::from(900));
|
||||
let schema = stream.schema();
|
||||
|
||||
// Remove 10 rows from the last batch to simulate ending with partial clump
|
||||
let mut batches = stream.try_collect::<Vec<_>>().await.unwrap();
|
||||
let last_index = batches.len() - 1;
|
||||
let sliced_last = batches[last_index].slice(0, 890);
|
||||
batches[last_index] = sliced_last;
|
||||
|
||||
let stream = Box::pin(SimpleRecordBatchStream::new(
|
||||
futures::stream::iter(batches).map(Ok).boxed(),
|
||||
schema.clone(),
|
||||
));
|
||||
|
||||
let result_stream = shuffler.shuffle(stream, 4490).await.unwrap();
|
||||
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
|
||||
let result_batch = concat_batches(&schema, &result_batches).unwrap();
|
||||
|
||||
let ids = result_batch.column(0).as_primitive::<Int32Type>();
|
||||
let mut iter = ids.into_iter().map(|o| o.unwrap());
|
||||
while let Some(first) = iter.next() {
|
||||
let rows_left_in_clump = if first == 4470 { 19 } else { 29 };
|
||||
let mut expected_next = first + 1;
|
||||
for _ in 0..rows_left_in_clump {
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next, expected_next);
|
||||
expected_next += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
804
rust/lancedb/src/dataloader/permutation/split.rs
Normal file
804
rust/lancedb/src/dataloader/permutation/split.rs
Normal file
@@ -0,0 +1,804 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{
|
||||
iter,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use datafusion_common::hash_utils::create_hashes;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::arrow::SchemaExt;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
dataloader::{
|
||||
permutation::shuffle::{Shuffler, ShufflerConfig},
|
||||
permutation::util::TemporaryDirectory,
|
||||
},
|
||||
query::{Query, QueryBase, Select},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
pub const SPLIT_ID_COLUMN: &str = "split_id";
|
||||
|
||||
/// Strategy for assigning rows to splits
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SplitStrategy {
|
||||
/// All rows will have split id 0
|
||||
NoSplit,
|
||||
/// Rows will be randomly assigned to splits
|
||||
///
|
||||
/// A seed can be provided to make the assignment deterministic.
|
||||
Random {
|
||||
seed: Option<u64>,
|
||||
sizes: SplitSizes,
|
||||
},
|
||||
/// Rows will be assigned to splits based on the values in the specified columns.
|
||||
///
|
||||
/// This will ensure rows are always assigned to the same split if the given columns do not change.
|
||||
///
|
||||
/// The `split_weights` are used to determine the approximate number of rows in each split. This
|
||||
/// controls how we divide up the u64 hash space. However, it does not guarantee any particular division
|
||||
/// of rows. For example, if all rows have identical hash values then all rows will be assigned to the same split
|
||||
/// regardless of the weights.
|
||||
///
|
||||
/// The `discard_weight` controls what percentage of rows should be throw away. For example, if you want your
|
||||
/// first split to have ~5% of your rows and the second split to have ~10% of your rows then you would set
|
||||
/// split_weights to [1, 2] and discard weight to 17 (or you could set split_weights to [5, 10] and discard_weight
|
||||
/// to 85). If you set discard_weight to 0 then all rows will be assigned to a split.
|
||||
Hash {
|
||||
columns: Vec<String>,
|
||||
split_weights: Vec<u64>,
|
||||
discard_weight: u64,
|
||||
},
|
||||
/// Rows will be assigned to splits sequentially.
|
||||
///
|
||||
/// The first N1 rows are assigned to split 1, the next N2 rows are assigned to split 2, etc.
|
||||
///
|
||||
/// This is mainly useful for debugging and testing.
|
||||
Sequential { sizes: SplitSizes },
|
||||
/// Rows will be assigned to splits based on a calculation of one or more columns.
|
||||
///
|
||||
/// This is useful when the splits already exist in the base table.
|
||||
///
|
||||
/// The provided `calculation` should be an SQL statement that returns an integer value between
|
||||
/// 0 and the number of splits - 1 (the number of splits is defined by the `splits` configuration).
|
||||
///
|
||||
/// If this strategy is used then the counts/percentages in the SplitSizes are ignored.
|
||||
Calculated { calculation: String },
|
||||
}
|
||||
|
||||
// The default is not to split the data
|
||||
//
|
||||
// All data will be assigned to a single split.
|
||||
impl Default for SplitStrategy {
|
||||
fn default() -> Self {
|
||||
Self::NoSplit
|
||||
}
|
||||
}
|
||||
|
||||
impl SplitStrategy {
|
||||
pub fn validate(&self, num_rows: u64) -> Result<()> {
|
||||
match self {
|
||||
Self::NoSplit => Ok(()),
|
||||
Self::Random { sizes, .. } => sizes.validate(num_rows),
|
||||
Self::Hash {
|
||||
split_weights,
|
||||
columns,
|
||||
..
|
||||
} => {
|
||||
if columns.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Hash strategy requires at least one column".to_string(),
|
||||
});
|
||||
}
|
||||
if split_weights.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Hash strategy requires at least one split weight".to_string(),
|
||||
});
|
||||
}
|
||||
if split_weights.contains(&0) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Split weights must be greater than 0".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::Sequential { sizes } => sizes.validate(num_rows),
|
||||
Self::Calculated { .. } => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Splitter {
|
||||
temp_dir: TemporaryDirectory,
|
||||
strategy: SplitStrategy,
|
||||
}
|
||||
|
||||
impl Splitter {
|
||||
pub fn new(temp_dir: TemporaryDirectory, strategy: SplitStrategy) -> Self {
|
||||
Self { temp_dir, strategy }
|
||||
}
|
||||
|
||||
fn sequential_split_id(
|
||||
num_rows: u64,
|
||||
split_sizes: &[u64],
|
||||
split_index: &AtomicUsize,
|
||||
counter_in_split: &AtomicU64,
|
||||
exhausted: &AtomicBool,
|
||||
) -> UInt64Array {
|
||||
let mut split_ids = Vec::<u64>::with_capacity(num_rows as usize);
|
||||
|
||||
while split_ids.len() < num_rows as usize {
|
||||
let split_id = split_index.load(Ordering::Relaxed);
|
||||
let counter = counter_in_split.load(Ordering::Relaxed);
|
||||
|
||||
let split_size = split_sizes[split_id];
|
||||
let remaining_in_split = split_size - counter;
|
||||
|
||||
let remaining_in_batch = num_rows - split_ids.len() as u64;
|
||||
|
||||
let mut done = false;
|
||||
let rows_to_add = if remaining_in_batch < remaining_in_split {
|
||||
counter_in_split.fetch_add(remaining_in_batch, Ordering::Relaxed);
|
||||
remaining_in_batch
|
||||
} else {
|
||||
split_index.fetch_add(1, Ordering::Relaxed);
|
||||
counter_in_split.store(0, Ordering::Relaxed);
|
||||
if split_id == split_sizes.len() - 1 {
|
||||
exhausted.store(true, Ordering::Relaxed);
|
||||
done = true;
|
||||
}
|
||||
remaining_in_split
|
||||
};
|
||||
|
||||
split_ids.extend(iter::repeat(split_id as u64).take(rows_to_add as usize));
|
||||
if done {
|
||||
// Quit early if we've run out of splits
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
UInt64Array::from(split_ids)
|
||||
}
|
||||
|
||||
async fn apply_sequential(
|
||||
&self,
|
||||
source: SendableRecordBatchStream,
|
||||
num_rows: u64,
|
||||
sizes: &SplitSizes,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let split_sizes = sizes.to_counts(num_rows);
|
||||
|
||||
let split_index = AtomicUsize::new(0);
|
||||
let counter_in_split = AtomicU64::new(0);
|
||||
let exhausted = AtomicBool::new(false);
|
||||
|
||||
let schema = source.schema();
|
||||
|
||||
let new_schema = Arc::new(schema.try_with_column(Field::new(
|
||||
SPLIT_ID_COLUMN,
|
||||
DataType::UInt64,
|
||||
false,
|
||||
))?);
|
||||
|
||||
let new_schema_clone = new_schema.clone();
|
||||
let stream = source.filter_map(move |batch| {
|
||||
let batch = match batch {
|
||||
Ok(batch) => batch,
|
||||
Err(e) => {
|
||||
return std::future::ready(Some(Err(e)));
|
||||
}
|
||||
};
|
||||
|
||||
if exhausted.load(Ordering::Relaxed) {
|
||||
return std::future::ready(None);
|
||||
}
|
||||
|
||||
let split_ids = Self::sequential_split_id(
|
||||
batch.num_rows() as u64,
|
||||
&split_sizes,
|
||||
&split_index,
|
||||
&counter_in_split,
|
||||
&exhausted,
|
||||
);
|
||||
|
||||
let mut arrays = batch.columns().to_vec();
|
||||
// This can happen if we exhaust all splits in the middle of a batch
|
||||
if split_ids.len() < batch.num_rows() {
|
||||
arrays = arrays
|
||||
.iter()
|
||||
.map(|arr| arr.slice(0, split_ids.len()))
|
||||
.collect();
|
||||
}
|
||||
arrays.push(Arc::new(split_ids));
|
||||
|
||||
std::future::ready(Some(Ok(
|
||||
RecordBatch::try_new(new_schema.clone(), arrays).unwrap()
|
||||
)))
|
||||
});
|
||||
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(
|
||||
stream,
|
||||
new_schema_clone,
|
||||
)))
|
||||
}
|
||||
|
||||
fn hash_split_id(batch: &RecordBatch, thresholds: &[u64], total_weight: u64) -> UInt64Array {
|
||||
let arrays = batch
|
||||
.columns()
|
||||
.iter()
|
||||
// Don't hash the last column which should always be the row id
|
||||
.take(batch.columns().len() - 1)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let mut hashes = vec![0; batch.num_rows()];
|
||||
let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
|
||||
create_hashes(&arrays, &random_state, &mut hashes).unwrap();
|
||||
// As an example, let's assume the weights are 1, 2. Our total weight is 3.
|
||||
//
|
||||
// Our thresholds are [1, 3]
|
||||
// Our modulo output will be 0, 1, or 2.
|
||||
//
|
||||
// thresholds.binary_search(0) => Err(0) => 0
|
||||
// thresholds.binary_search(1) => Ok(0) => 1
|
||||
// thresholds.binary_search(2) => Err(1) => 1
|
||||
let split_ids = hashes
|
||||
.iter()
|
||||
.map(|h| {
|
||||
let h = h % total_weight;
|
||||
let split_id = match thresholds.binary_search(&h) {
|
||||
Ok(i) => (i + 1) as u64,
|
||||
Err(i) => i as u64,
|
||||
};
|
||||
if split_id == thresholds.len() as u64 {
|
||||
// If we're at the last threshold then we discard the row (indicated by setting
|
||||
// the split_id to null)
|
||||
None
|
||||
} else {
|
||||
Some(split_id)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
UInt64Array::from(split_ids)
|
||||
}
|
||||
|
||||
async fn apply_hash(
|
||||
&self,
|
||||
source: SendableRecordBatchStream,
|
||||
weights: &[u64],
|
||||
discard_weight: u64,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let row_id_index = source.schema().fields.len() - 1;
|
||||
let new_schema = Arc::new(Schema::new(vec![
|
||||
source.schema().field(row_id_index).clone(),
|
||||
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
|
||||
]));
|
||||
|
||||
let total_weight = weights.iter().sum::<u64>() + discard_weight;
|
||||
// Thresholds are the cumulative sum of the weights
|
||||
let mut offset = 0;
|
||||
let thresholds = weights
|
||||
.iter()
|
||||
.map(|w| {
|
||||
let value = offset + w;
|
||||
offset = value;
|
||||
value
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let new_schema_clone = new_schema.clone();
|
||||
let stream = source.map_ok(move |batch| {
|
||||
let split_ids = Self::hash_split_id(&batch, &thresholds, total_weight);
|
||||
|
||||
if split_ids.null_count() > 0 {
|
||||
let is_valid = split_ids.nulls().unwrap().inner();
|
||||
let is_valid_mask = BooleanArray::new(is_valid.clone(), None);
|
||||
let split_ids = arrow::compute::filter(&split_ids, &is_valid_mask).unwrap();
|
||||
let row_ids = batch.column(row_id_index);
|
||||
let row_ids = arrow::compute::filter(row_ids.as_ref(), &is_valid_mask).unwrap();
|
||||
RecordBatch::try_new(new_schema.clone(), vec![row_ids, split_ids]).unwrap()
|
||||
} else {
|
||||
RecordBatch::try_new(
|
||||
new_schema.clone(),
|
||||
vec![batch.column(row_id_index).clone(), Arc::new(split_ids)],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(
|
||||
stream,
|
||||
new_schema_clone,
|
||||
)))
|
||||
}
|
||||
|
||||
pub async fn apply(
|
||||
&self,
|
||||
source: SendableRecordBatchStream,
|
||||
num_rows: u64,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
self.strategy.validate(num_rows)?;
|
||||
|
||||
match &self.strategy {
|
||||
// For consistency, even if no-split, we still give a split id column of all 0s
|
||||
SplitStrategy::NoSplit => {
|
||||
self.apply_sequential(source, num_rows, &SplitSizes::Counts(vec![num_rows]))
|
||||
.await
|
||||
}
|
||||
SplitStrategy::Random { seed, sizes } => {
|
||||
let shuffler = Shuffler::new(ShufflerConfig {
|
||||
seed: *seed,
|
||||
// In this case we are only shuffling row ids so we can use a large max_rows_per_file
|
||||
max_rows_per_file: 10 * 1024 * 1024,
|
||||
temp_dir: self.temp_dir.clone(),
|
||||
clump_size: None,
|
||||
});
|
||||
|
||||
let shuffled = shuffler.shuffle(source, num_rows).await?;
|
||||
|
||||
self.apply_sequential(shuffled, num_rows, sizes).await
|
||||
}
|
||||
SplitStrategy::Sequential { sizes } => {
|
||||
self.apply_sequential(source, num_rows, sizes).await
|
||||
}
|
||||
// Nothing to do, split is calculated in projection
|
||||
SplitStrategy::Calculated { .. } => Ok(source),
|
||||
SplitStrategy::Hash {
|
||||
split_weights,
|
||||
discard_weight,
|
||||
..
|
||||
} => {
|
||||
self.apply_hash(source, split_weights, *discard_weight)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project(&self, query: Query) -> Query {
|
||||
match &self.strategy {
|
||||
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
|
||||
SPLIT_ID_COLUMN.to_string(),
|
||||
calculation.clone(),
|
||||
)])),
|
||||
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
|
||||
_ => query,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn orders_by_split_id(&self) -> bool {
|
||||
match &self.strategy {
|
||||
SplitStrategy::Hash { .. } | SplitStrategy::Calculated { .. } => true,
|
||||
SplitStrategy::NoSplit
|
||||
| SplitStrategy::Sequential { .. }
|
||||
// It may be strange but for random we shuffle and then assign splits so the result is
|
||||
// sorted by split id
|
||||
| SplitStrategy::Random { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Split configuration - either percentages or absolute counts
|
||||
///
|
||||
/// If the percentages do not sum to 1.0 (or the counts do not sum to the total number of rows)
|
||||
/// the remaining rows will not be included in the permutation.
|
||||
///
|
||||
/// The default implementation assigns all rows to a single split.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SplitSizes {
|
||||
/// Percentage splits (must sum to <= 1.0)
|
||||
///
|
||||
/// The number of rows in each split is the nearest integer to the percentage multiplied by
|
||||
/// the total number of rows.
|
||||
Percentages(Vec<f64>),
|
||||
/// Absolute row counts per split
|
||||
///
|
||||
/// If the dataset doesn't contain enough matching rows to fill all splits then an error
|
||||
/// will be raised.
|
||||
Counts(Vec<u64>),
|
||||
/// Divides data into a fixed number of splits
|
||||
///
|
||||
/// Will divide the data evenly.
|
||||
///
|
||||
/// If the number of rows is not divisible by the number of splits then the rows per split
|
||||
/// is rounded down.
|
||||
Fixed(u64),
|
||||
}
|
||||
|
||||
impl Default for SplitSizes {
|
||||
fn default() -> Self {
|
||||
Self::Percentages(vec![1.0])
|
||||
}
|
||||
}
|
||||
|
||||
impl SplitSizes {
|
||||
pub fn validate(&self, num_rows: u64) -> Result<()> {
|
||||
match self {
|
||||
Self::Percentages(percentages) => {
|
||||
for percentage in percentages {
|
||||
if *percentage < 0.0 || *percentage > 1.0 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Split percentages must be between 0.0 and 1.0".to_string(),
|
||||
});
|
||||
}
|
||||
if percentage * (num_rows as f64) < 1.0 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"One of the splits has {}% of {} rows which rounds to 0 rows",
|
||||
percentage * 100.0,
|
||||
num_rows
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
if percentages.iter().sum::<f64>() > 1.0 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Split percentages must sum to 1.0 or less".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Self::Counts(counts) => {
|
||||
if counts.iter().sum::<u64>() > num_rows {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"Split counts specified {} rows but only {} are available",
|
||||
counts.iter().sum::<u64>(),
|
||||
num_rows
|
||||
),
|
||||
});
|
||||
}
|
||||
if counts.contains(&0) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Split counts must be greater than 0".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Self::Fixed(num_splits) => {
|
||||
if *num_splits > num_rows {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
|
||||
*num_splits, num_rows
|
||||
),
|
||||
});
|
||||
}
|
||||
if (num_rows / num_splits) == 0 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
|
||||
*num_splits, num_rows
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn to_counts(&self, num_rows: u64) -> Vec<u64> {
|
||||
let sizes = match self {
|
||||
Self::Percentages(percentages) => {
|
||||
let mut percentage_sum = 0.0_f64;
|
||||
let mut counts = percentages
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let count = (p * (num_rows as f64)).round() as u64;
|
||||
percentage_sum += p;
|
||||
count
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let sum = counts.iter().sum::<u64>();
|
||||
|
||||
let is_basically_one =
|
||||
(num_rows as f64 - percentage_sum * num_rows as f64).abs() < 0.5;
|
||||
|
||||
// If the sum of percentages is close to 1.0 then rounding errors can add up
|
||||
// to more or less than num_rows
|
||||
//
|
||||
// Drop items from buckets until we have the correct number of rows
|
||||
let mut excess = sum as i64 - num_rows as i64;
|
||||
let mut drop_idx = 0;
|
||||
while excess > 0 {
|
||||
if counts[drop_idx] > 0 {
|
||||
counts[drop_idx] -= 1;
|
||||
excess -= 1;
|
||||
}
|
||||
drop_idx += 1;
|
||||
if drop_idx == counts.len() {
|
||||
drop_idx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// On the other hand, if the percentages sum to ~1.0 then the we also shouldn't _lose_
|
||||
// rows due to rounding errors
|
||||
let mut add_idx = 0;
|
||||
while is_basically_one && excess < 0 {
|
||||
counts[add_idx] += 1;
|
||||
add_idx += 1;
|
||||
excess += 1;
|
||||
if add_idx == counts.len() {
|
||||
add_idx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
counts
|
||||
}
|
||||
Self::Counts(counts) => counts.clone(),
|
||||
Self::Fixed(num_splits) => {
|
||||
let rows_per_split = num_rows / *num_splits;
|
||||
vec![rows_per_split; *num_splits as usize]
|
||||
}
|
||||
};
|
||||
|
||||
assert!(sizes.iter().sum::<u64>() <= num_rows);
|
||||
|
||||
sizes
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::arrow::LanceDbDatagenExt;
|
||||
|
||||
use super::*;
|
||||
use arrow::{
|
||||
array::AsArray,
|
||||
compute::concat_batches,
|
||||
datatypes::{Int32Type, UInt64Type},
|
||||
};
|
||||
use arrow_array::Int32Array;
|
||||
use futures::TryStreamExt;
|
||||
use lance_datagen::{BatchCount, ByteCount, RowCount, Seed};
|
||||
use std::sync::Arc;
|
||||
|
||||
const ID_COLUMN: &str = "id";
|
||||
|
||||
#[test]
|
||||
fn test_split_sizes_percentages_validation() {
|
||||
// Valid percentages
|
||||
let sizes = SplitSizes::Percentages(vec![0.7, 0.3]);
|
||||
assert!(sizes.validate(100).is_ok());
|
||||
|
||||
// Sum > 1.0
|
||||
let sizes = SplitSizes::Percentages(vec![0.7, 0.4]);
|
||||
assert!(sizes.validate(100).is_err());
|
||||
|
||||
// Negative percentage
|
||||
let sizes = SplitSizes::Percentages(vec![-0.1, 0.5]);
|
||||
assert!(sizes.validate(100).is_err());
|
||||
|
||||
// Percentage > 1.0
|
||||
let sizes = SplitSizes::Percentages(vec![1.5]);
|
||||
assert!(sizes.validate(100).is_err());
|
||||
|
||||
// Percentage rounds to 0 rows
|
||||
let sizes = SplitSizes::Percentages(vec![0.001]);
|
||||
assert!(sizes.validate(100).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_sizes_counts_validation() {
|
||||
// Valid counts
|
||||
let sizes = SplitSizes::Counts(vec![30, 70]);
|
||||
assert!(sizes.validate(100).is_ok());
|
||||
|
||||
// Sum > num_rows
|
||||
let sizes = SplitSizes::Counts(vec![60, 50]);
|
||||
assert!(sizes.validate(100).is_err());
|
||||
|
||||
// Counts are 0
|
||||
let sizes = SplitSizes::Counts(vec![0, 100]);
|
||||
assert!(sizes.validate(100).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_sizes_fixed_validation() {
|
||||
// Valid fixed splits
|
||||
let sizes = SplitSizes::Fixed(5);
|
||||
assert!(sizes.validate(100).is_ok());
|
||||
|
||||
// More splits than rows
|
||||
let sizes = SplitSizes::Fixed(150);
|
||||
assert!(sizes.validate(100).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_sizes_to_sizes_percentages() {
|
||||
let sizes = SplitSizes::Percentages(vec![0.3, 0.7]);
|
||||
let result = sizes.to_counts(100);
|
||||
assert_eq!(result, vec![30, 70]);
|
||||
|
||||
// Test rounding
|
||||
let sizes = SplitSizes::Percentages(vec![0.3, 0.41]);
|
||||
let result = sizes.to_counts(70);
|
||||
assert_eq!(result, vec![21, 29]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_sizes_to_sizes_fixed() {
|
||||
let sizes = SplitSizes::Fixed(4);
|
||||
let result = sizes.to_counts(100);
|
||||
assert_eq!(result, vec![25, 25, 25, 25]);
|
||||
|
||||
// Test with remainder
|
||||
let sizes = SplitSizes::Fixed(3);
|
||||
let result = sizes.to_counts(10);
|
||||
assert_eq!(result, vec![3, 3, 3]);
|
||||
}
|
||||
|
||||
fn test_data() -> SendableRecordBatchStream {
|
||||
lance_datagen::gen_batch()
|
||||
.with_seed(Seed::from(42))
|
||||
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
|
||||
.into_ldb_stream(RowCount::from(10), BatchCount::from(5))
|
||||
}
|
||||
|
||||
async fn verify_splitter(
|
||||
splitter: Splitter,
|
||||
data: SendableRecordBatchStream,
|
||||
num_rows: u64,
|
||||
expected_split_sizes: &[u64],
|
||||
row_ids_in_order: bool,
|
||||
) {
|
||||
let split_batches = splitter
|
||||
.apply(data, num_rows)
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let schema = split_batches[0].schema();
|
||||
let split_batch = concat_batches(&schema, &split_batches).unwrap();
|
||||
|
||||
let total_split_sizes = expected_split_sizes.iter().sum::<u64>();
|
||||
|
||||
assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
|
||||
let mut expected = Vec::with_capacity(total_split_sizes as usize);
|
||||
for (i, size) in expected_split_sizes.iter().enumerate() {
|
||||
expected.extend(iter::repeat(i as u64).take(*size as usize));
|
||||
}
|
||||
let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;
|
||||
|
||||
assert_eq!(&expected, split_batch.column(1));
|
||||
|
||||
let expected_row_ids =
|
||||
Arc::new(Int32Array::from_iter_values(0..total_split_sizes as i32)) as Arc<dyn Array>;
|
||||
if row_ids_in_order {
|
||||
assert_eq!(&expected_row_ids, split_batch.column(0));
|
||||
} else {
|
||||
assert_ne!(&expected_row_ids, split_batch.column(0));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fixed_sequential_split() {
|
||||
let splitter = Splitter::new(
|
||||
// Sequential splitting doesn't need a temp dir
|
||||
TemporaryDirectory::None,
|
||||
SplitStrategy::Sequential {
|
||||
sizes: SplitSizes::Fixed(3),
|
||||
},
|
||||
);
|
||||
|
||||
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], true).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fixed_random_split() {
|
||||
let splitter = Splitter::new(
|
||||
TemporaryDirectory::None,
|
||||
SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Fixed(3),
|
||||
},
|
||||
);
|
||||
|
||||
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], false).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_counts_sequential_split() {
|
||||
let splitter = Splitter::new(
|
||||
// Sequential splitting doesn't need a temp dir
|
||||
TemporaryDirectory::None,
|
||||
SplitStrategy::Sequential {
|
||||
sizes: SplitSizes::Counts(vec![5, 15, 10]),
|
||||
},
|
||||
);
|
||||
|
||||
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], true).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_counts_random_split() {
|
||||
let splitter = Splitter::new(
|
||||
TemporaryDirectory::None,
|
||||
SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Counts(vec![5, 15, 10]),
|
||||
},
|
||||
);
|
||||
|
||||
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], false).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_percentages_sequential_split() {
|
||||
let splitter = Splitter::new(
|
||||
// Sequential splitting doesn't need a temp dir
|
||||
TemporaryDirectory::None,
|
||||
SplitStrategy::Sequential {
|
||||
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
|
||||
},
|
||||
);
|
||||
|
||||
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], true).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_percentages_random_split() {
|
||||
let splitter = Splitter::new(
|
||||
TemporaryDirectory::None,
|
||||
SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
|
||||
},
|
||||
);
|
||||
|
||||
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], false).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hash_split() {
|
||||
let data = lance_datagen::gen_batch()
|
||||
.with_seed(Seed::from(42))
|
||||
.col(
|
||||
"hash1",
|
||||
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
|
||||
)
|
||||
.col("hash2", lance_datagen::array::step::<Int32Type>())
|
||||
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
|
||||
.into_ldb_stream(RowCount::from(10), BatchCount::from(5));
|
||||
|
||||
let splitter = Splitter::new(
|
||||
TemporaryDirectory::None,
|
||||
SplitStrategy::Hash {
|
||||
columns: vec!["hash1".to_string(), "hash2".to_string()],
|
||||
split_weights: vec![1, 2],
|
||||
discard_weight: 1,
|
||||
},
|
||||
);
|
||||
|
||||
let split_batches = splitter
|
||||
.apply(data, 10)
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let schema = split_batches[0].schema();
|
||||
let split_batch = concat_batches(&schema, &split_batches).unwrap();
|
||||
|
||||
// These assertions are all based on fixed seed in data generation but they match
|
||||
// up roughly to what we expect (25% discarded, 25% in split 0, 50% in split 1)
|
||||
|
||||
// 14 rows (28%) are discarded because discard_weight is 1
|
||||
assert_eq!(split_batch.num_rows(), 36);
|
||||
assert_eq!(split_batch.num_columns(), 2);
|
||||
|
||||
let split_ids = split_batch.column(1).as_primitive::<UInt64Type>().values();
|
||||
let num_in_split_0 = split_ids.iter().filter(|v| **v == 0).count();
|
||||
let num_in_split_1 = split_ids.iter().filter(|v| **v == 1).count();
|
||||
|
||||
assert_eq!(num_in_split_0, 11); // 22%
|
||||
assert_eq!(num_in_split_1, 25); // 50%
|
||||
}
|
||||
}
|
||||
98
rust/lancedb/src/dataloader/permutation/util.rs
Normal file
98
rust/lancedb/src/dataloader/permutation/util.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
use arrow_schema::{Fields, Schema};
|
||||
use datafusion_execution::disk_manager::DiskManagerMode;
|
||||
use futures::TryStreamExt;
|
||||
use rand::{rngs::SmallRng, RngCore, SeedableRng};
|
||||
use tempfile::TempDir;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
/// Directory to use for temporary files
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub enum TemporaryDirectory {
|
||||
/// Use the operating system's default temporary directory (e.g. /tmp)
|
||||
#[default]
|
||||
OsDefault,
|
||||
/// Use the specified directory (must be an absolute path)
|
||||
Specific(PathBuf),
|
||||
/// If spilling is required, then error out
|
||||
None,
|
||||
}
|
||||
|
||||
impl TemporaryDirectory {
|
||||
pub fn create_temp_dir(&self) -> Result<TempDir> {
|
||||
match self {
|
||||
Self::OsDefault => tempfile::tempdir(),
|
||||
Self::Specific(path) => tempfile::Builder::default().tempdir_in(path),
|
||||
Self::None => {
|
||||
return Err(Error::Runtime {
|
||||
message: "No temporary directory was supplied and this operation requires spilling to disk".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
.map_err(|err| Error::Other {
|
||||
message: "Failed to create temporary directory".to_string(),
|
||||
source: Some(err.into()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_disk_manager_mode(&self) -> DiskManagerMode {
|
||||
match self {
|
||||
Self::OsDefault => DiskManagerMode::OsTmpDirectory,
|
||||
Self::Specific(path) => DiskManagerMode::Directories(vec![path.clone()]),
|
||||
Self::None => DiskManagerMode::Disabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn non_crypto_rng(seed: &Option<u64>) -> Box<dyn RngCore + Send> {
|
||||
Box::new(
|
||||
seed.as_ref()
|
||||
.map(|seed| SmallRng::seed_from_u64(*seed))
|
||||
.unwrap_or_else(SmallRng::from_os_rng),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn rename_column(
|
||||
stream: SendableRecordBatchStream,
|
||||
old_name: &str,
|
||||
new_name: &str,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let schema = stream.schema();
|
||||
let field_index = schema.index_of(old_name)?;
|
||||
|
||||
let new_fields = schema
|
||||
.fields
|
||||
.iter()
|
||||
.cloned()
|
||||
.enumerate()
|
||||
.map(|(idx, f)| {
|
||||
if idx == field_index {
|
||||
Arc::new(f.as_ref().clone().with_name(new_name))
|
||||
} else {
|
||||
f
|
||||
}
|
||||
})
|
||||
.collect::<Fields>();
|
||||
let new_schema = Arc::new(Schema::new(new_fields).with_metadata(schema.metadata().clone()));
|
||||
let new_schema_clone = new_schema.clone();
|
||||
|
||||
let renamed_stream = stream.and_then(move |batch| {
|
||||
let renamed_batch =
|
||||
RecordBatch::try_new(new_schema.clone(), batch.columns().to_vec()).map_err(Error::from);
|
||||
std::future::ready(renamed_batch)
|
||||
});
|
||||
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(
|
||||
renamed_stream,
|
||||
new_schema_clone,
|
||||
)))
|
||||
}
|
||||
@@ -8,6 +8,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use vector::IvfFlatIndexBuilder;
|
||||
|
||||
use crate::index::vector::IvfRqIndexBuilder;
|
||||
use crate::{table::BaseTable, DistanceType, Error, Result};
|
||||
|
||||
use self::{
|
||||
@@ -53,6 +54,9 @@ pub enum Index {
|
||||
/// IVF index with Product Quantization
|
||||
IvfPq(IvfPqIndexBuilder),
|
||||
|
||||
/// IVF index with RabitQ Quantization
|
||||
IvfRq(IvfRqIndexBuilder),
|
||||
|
||||
/// IVF-HNSW index with Product Quantization
|
||||
/// It is a variant of the HNSW algorithm that uses product quantization to compress the vectors.
|
||||
IvfHnswPq(IvfHnswPqIndexBuilder),
|
||||
@@ -275,6 +279,8 @@ pub enum IndexType {
|
||||
IvfFlat,
|
||||
#[serde(alias = "IVF_PQ")]
|
||||
IvfPq,
|
||||
#[serde(alias = "IVF_RQ")]
|
||||
IvfRq,
|
||||
#[serde(alias = "IVF_HNSW_PQ")]
|
||||
IvfHnswPq,
|
||||
#[serde(alias = "IVF_HNSW_SQ")]
|
||||
@@ -296,6 +302,7 @@ impl std::fmt::Display for IndexType {
|
||||
match self {
|
||||
Self::IvfFlat => write!(f, "IVF_FLAT"),
|
||||
Self::IvfPq => write!(f, "IVF_PQ"),
|
||||
Self::IvfRq => write!(f, "IVF_RQ"),
|
||||
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
|
||||
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
|
||||
Self::BTree => write!(f, "BTREE"),
|
||||
@@ -317,6 +324,7 @@ impl std::str::FromStr for IndexType {
|
||||
"FTS" | "INVERTED" => Ok(Self::FTS),
|
||||
"IVF_FLAT" => Ok(Self::IvfFlat),
|
||||
"IVF_PQ" => Ok(Self::IvfPq),
|
||||
"IVF_RQ" => Ok(Self::IvfRq),
|
||||
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
|
||||
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
|
||||
_ => Err(Error::InvalidInput {
|
||||
|
||||
@@ -291,6 +291,52 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for an IVF RQ index.
|
||||
///
|
||||
/// This index stores a compressed (quantized) copy of every vector. Each dimension
|
||||
/// is quantized into a small number of bits.
|
||||
/// The parameters `num_bits` control this process, providing a tradeoff
|
||||
/// between index size (and thus search speed) and index accuracy.
|
||||
///
|
||||
/// The partitioning process is called IVF and the `num_partitions` parameter controls how
|
||||
/// many groups to create.
|
||||
///
|
||||
/// Note that training an IVF RQ index on a large dataset is a slow operation and
|
||||
/// currently is also a memory intensive operation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IvfRqIndexBuilder {
|
||||
// IVF
|
||||
pub(crate) distance_type: DistanceType,
|
||||
pub(crate) num_partitions: Option<u32>,
|
||||
pub(crate) num_bits: Option<u32>,
|
||||
pub(crate) sample_rate: u32,
|
||||
pub(crate) max_iterations: u32,
|
||||
pub(crate) target_partition_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl Default for IvfRqIndexBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
distance_type: DistanceType::L2,
|
||||
num_partitions: None,
|
||||
num_bits: None,
|
||||
sample_rate: 256,
|
||||
max_iterations: 50,
|
||||
target_partition_size: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IvfRqIndexBuilder {
|
||||
impl_distance_type_setter!();
|
||||
impl_ivf_params_setter!();
|
||||
|
||||
pub fn num_bits(mut self, num_bits: u32) -> Self {
|
||||
self.num_bits = Some(num_bits);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for an IVF HNSW PQ index.
|
||||
///
|
||||
/// This index is a combination of IVF and HNSW.
|
||||
|
||||
@@ -194,6 +194,7 @@ pub mod arrow;
|
||||
pub mod connection;
|
||||
pub mod data;
|
||||
pub mod database;
|
||||
pub mod dataloader;
|
||||
pub mod embeddings;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
@@ -206,7 +207,8 @@ pub mod query;
|
||||
pub mod remote;
|
||||
pub mod rerankers;
|
||||
pub mod table;
|
||||
pub mod test_connection;
|
||||
#[cfg(test)]
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
@@ -6,10 +6,10 @@ use std::{future::Future, time::Duration};
|
||||
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
|
||||
use arrow_schema::DataType;
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
use datafusion_expr::Expr;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use futures::{stream, try_join, FutureExt, TryStreamExt};
|
||||
use futures::{stream, try_join, FutureExt, TryFutureExt, TryStreamExt};
|
||||
use half::f16;
|
||||
use lance::{
|
||||
arrow::RecordBatchExt,
|
||||
@@ -582,16 +582,40 @@ pub trait ExecutableQuery {
|
||||
options: QueryExecutionOptions,
|
||||
) -> impl Future<Output = Result<SendableRecordBatchStream>> + Send;
|
||||
|
||||
/// Explain the plan for a query
|
||||
///
|
||||
/// This will create a string representation of the plan that will be used to
|
||||
/// execute the query. This will not execute the query.
|
||||
///
|
||||
/// This function can be used to get an understanding of what work will be done by the query
|
||||
/// and is useful for debugging query performance.
|
||||
fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
|
||||
|
||||
/// Execute the query and display the runtime metrics
|
||||
///
|
||||
/// This shows the same plan as [`ExecutableQuery::explain_plan`] but includes runtime metrics.
|
||||
///
|
||||
/// This function will actually execute the query in order to get the runtime metrics.
|
||||
fn analyze_plan(&self) -> impl Future<Output = Result<String>> + Send {
|
||||
self.analyze_plan_with_options(QueryExecutionOptions::default())
|
||||
}
|
||||
|
||||
/// Execute the query and display the runtime metrics
|
||||
///
|
||||
/// This is the same as [`ExecutableQuery::analyze_plan`] but allows for specifying the execution options.
|
||||
fn analyze_plan_with_options(
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> impl Future<Output = Result<String>> + Send;
|
||||
|
||||
/// Return the output schema for data returned by the query without actually executing the query
|
||||
///
|
||||
/// This can be useful when the selection for a query is built dynamically as it is not always
|
||||
/// obvious what the output schema will be.
|
||||
fn output_schema(&self) -> impl Future<Output = Result<SchemaRef>> + Send {
|
||||
self.create_plan(QueryExecutionOptions::default())
|
||||
.and_then(|plan| std::future::ready(Ok(plan.schema())))
|
||||
}
|
||||
}
|
||||
|
||||
/// A query filter that can be applied to a query
|
||||
@@ -1505,6 +1529,16 @@ mod tests {
|
||||
.query()
|
||||
.limit(10)
|
||||
.select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")]));
|
||||
|
||||
let schema = query.output_schema().await.unwrap();
|
||||
assert_eq!(
|
||||
schema,
|
||||
Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("id2", DataType::Int32, true),
|
||||
ArrowField::new("id", DataType::Int32, true),
|
||||
]))
|
||||
);
|
||||
|
||||
let result = query.execute().await;
|
||||
let mut batches = result
|
||||
.expect("should have result")
|
||||
|
||||
@@ -16,7 +16,7 @@ use tokio::task::spawn_blocking;
|
||||
use crate::database::{
|
||||
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
|
||||
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
|
||||
OpenTableRequest, TableNamesRequest,
|
||||
OpenTableRequest, ReadConsistency, TableNamesRequest,
|
||||
};
|
||||
use crate::error::Result;
|
||||
use crate::table::BaseTable;
|
||||
@@ -189,6 +189,7 @@ struct ListTablesResponse {
|
||||
pub struct RemoteDatabase<S: HttpSend = Sender> {
|
||||
client: RestfulLanceDbClient<S>,
|
||||
table_cache: Cache<String, Arc<RemoteTable<S>>>,
|
||||
uri: String,
|
||||
}
|
||||
|
||||
impl RemoteDatabase {
|
||||
@@ -217,6 +218,7 @@ impl RemoteDatabase {
|
||||
Ok(Self {
|
||||
client,
|
||||
table_cache,
|
||||
uri: uri.to_owned(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -238,6 +240,7 @@ mod test_utils {
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,6 +253,7 @@ mod test_utils {
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -315,6 +319,17 @@ fn build_cache_key(name: &str, namespace: &[String]) -> String {
|
||||
|
||||
#[async_trait]
|
||||
impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
fn uri(&self) -> &str {
|
||||
&self.uri
|
||||
}
|
||||
|
||||
async fn read_consistency(&self) -> Result<ReadConsistency> {
|
||||
Err(Error::NotSupported {
|
||||
message: "Getting the read consistency of a remote database is not yet supported"
|
||||
.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
|
||||
let mut req = if !request.namespace.is_empty() {
|
||||
let namespace_id =
|
||||
|
||||
@@ -50,6 +50,7 @@ use std::sync::Arc;
|
||||
|
||||
use crate::arrow::IntoArrow;
|
||||
use crate::connection::NoData;
|
||||
use crate::database::Database;
|
||||
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
|
||||
@@ -510,6 +511,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
/// Get the namespace of the table.
|
||||
fn namespace(&self) -> &[String];
|
||||
/// Get the id of the table
|
||||
///
|
||||
/// This is the namespace of the table concatenated with the name
|
||||
/// separated by a dot (".")
|
||||
fn id(&self) -> &str;
|
||||
/// Get the arrow [Schema] of the table.
|
||||
async fn schema(&self) -> Result<SchemaRef>;
|
||||
@@ -611,9 +615,10 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
/// A Table is a collection of strong typed Rows.
|
||||
///
|
||||
/// The type of the each row is defined in Apache Arrow [Schema].
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Table {
|
||||
inner: Arc<dyn BaseTable>,
|
||||
database: Arc<dyn Database>,
|
||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||
}
|
||||
|
||||
@@ -631,11 +636,13 @@ mod test_utils {
|
||||
{
|
||||
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
|
||||
name.into(),
|
||||
handler,
|
||||
handler.clone(),
|
||||
None,
|
||||
));
|
||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
// Registry is unused.
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
@@ -651,11 +658,13 @@ mod test_utils {
|
||||
{
|
||||
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
|
||||
name.into(),
|
||||
handler,
|
||||
handler.clone(),
|
||||
Some(version),
|
||||
));
|
||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
// Registry is unused.
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
@@ -670,9 +679,10 @@ impl std::fmt::Display for Table {
|
||||
}
|
||||
|
||||
impl Table {
|
||||
pub fn new(inner: Arc<dyn BaseTable>) -> Self {
|
||||
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
@@ -681,12 +691,22 @@ impl Table {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
pub fn database(&self) -> &Arc<dyn Database> {
|
||||
&self.database
|
||||
}
|
||||
|
||||
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
|
||||
&self.embedding_registry
|
||||
}
|
||||
|
||||
pub(crate) fn new_with_embedding_registry(
|
||||
inner: Arc<dyn BaseTable>,
|
||||
database: Arc<dyn Database>,
|
||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
embedding_registry,
|
||||
}
|
||||
}
|
||||
@@ -1416,12 +1436,6 @@ impl Tags for NativeTags {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NativeTable> for Table {
|
||||
fn from(table: NativeTable) -> Self {
|
||||
Self::new(Arc::new(table))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NativeTableExt {
|
||||
/// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
|
||||
fn as_native(&self) -> Option<&NativeTable>;
|
||||
@@ -1843,6 +1857,18 @@ impl NativeTable {
|
||||
);
|
||||
Ok(Box::new(lance_idx_params))
|
||||
}
|
||||
Index::IvfRq(index) => {
|
||||
Self::validate_index_type(field, "IVF RQ", supported_vector_data_type)?;
|
||||
let num_partitions = self
|
||||
.get_num_partitions(index.num_partitions, false, None)
|
||||
.await?;
|
||||
let lance_idx_params = VectorIndexParams::ivf_rq(
|
||||
num_partitions as usize,
|
||||
index.num_bits.unwrap_or(1) as u8,
|
||||
index.distance_type.into(),
|
||||
);
|
||||
Ok(Box::new(lance_idx_params))
|
||||
}
|
||||
Index::IvfHnswPq(index) => {
|
||||
Self::validate_index_type(field, "IVF HNSW PQ", supported_vector_data_type)?;
|
||||
let dim = Self::get_vector_dimension(field)?;
|
||||
@@ -1912,9 +1938,11 @@ impl NativeTable {
|
||||
Index::Bitmap(_) => IndexType::Bitmap,
|
||||
Index::LabelList(_) => IndexType::LabelList,
|
||||
Index::FTS(_) => IndexType::Inverted,
|
||||
Index::IvfFlat(_) | Index::IvfPq(_) | Index::IvfHnswPq(_) | Index::IvfHnswSq(_) => {
|
||||
IndexType::Vector
|
||||
}
|
||||
Index::IvfFlat(_)
|
||||
| Index::IvfPq(_)
|
||||
| Index::IvfRq(_)
|
||||
| Index::IvfHnswPq(_)
|
||||
| Index::IvfHnswSq(_) => IndexType::Vector,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -125,6 +125,10 @@ impl ExecutionPlan for MetadataEraserExec {
|
||||
fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
|
||||
self.input.partition_statistics(partition)
|
||||
}
|
||||
|
||||
fn supports_limit_pushdown(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Functions for testing connections.
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_utils {
|
||||
use regex::Regex;
|
||||
use std::env;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::process::{Child, ChildStdout, Command, Stdio};
|
||||
|
||||
use crate::{connect, Connection};
|
||||
use anyhow::{bail, Result};
|
||||
use tempfile::{tempdir, TempDir};
|
||||
|
||||
pub struct TestConnection {
|
||||
pub uri: String,
|
||||
pub connection: Connection,
|
||||
_temp_dir: Option<TempDir>,
|
||||
_process: Option<TestProcess>,
|
||||
}
|
||||
|
||||
struct TestProcess {
|
||||
child: Child,
|
||||
}
|
||||
|
||||
impl Drop for TestProcess {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
self.child.kill();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_test_connection() -> Result<TestConnection> {
|
||||
match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
|
||||
Ok(script_path) => new_remote_connection(&script_path).await,
|
||||
Err(_e) => new_local_connection().await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
|
||||
let temp_dir = tempdir()?;
|
||||
let data_path = temp_dir.path().to_str().unwrap().to_string();
|
||||
let child_result = Command::new(script_path)
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.arg(data_path.clone())
|
||||
.spawn();
|
||||
if child_result.is_err() {
|
||||
bail!(format!(
|
||||
"Unable to run {}: {:?}",
|
||||
script_path,
|
||||
child_result.err()
|
||||
));
|
||||
}
|
||||
let mut process = TestProcess {
|
||||
child: child_result.unwrap(),
|
||||
};
|
||||
let stdout = BufReader::new(process.child.stdout.take().unwrap());
|
||||
let port = read_process_port(stdout)?;
|
||||
let uri = "db://test";
|
||||
let host_override = format!("http://localhost:{}", port);
|
||||
let connection = create_new_connection(uri, &host_override).await?;
|
||||
Ok(TestConnection {
|
||||
uri: uri.to_string(),
|
||||
connection,
|
||||
_temp_dir: Some(temp_dir),
|
||||
_process: Some(process),
|
||||
})
|
||||
}
|
||||
|
||||
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
|
||||
let mut line = String::new();
|
||||
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
|
||||
loop {
|
||||
let result = stdout.read_line(&mut line);
|
||||
if let Err(err) = result {
|
||||
bail!(format!(
|
||||
"read_process_port: error while reading from process output: {}",
|
||||
err
|
||||
));
|
||||
} else if result.unwrap() == 0 {
|
||||
bail!("read_process_port: hit EOF before reading port from process output.");
|
||||
}
|
||||
if re.is_match(&line) {
|
||||
let caps = re.captures(&line).unwrap();
|
||||
return Ok(caps[1].to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
async fn create_new_connection(
|
||||
uri: &str,
|
||||
host_override: &str,
|
||||
) -> crate::error::Result<Connection> {
|
||||
connect(uri)
|
||||
.region("us-east-1")
|
||||
.api_key("sk_localtest")
|
||||
.host_override(host_override)
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "remote"))]
|
||||
async fn create_new_connection(
|
||||
_uri: &str,
|
||||
_host_override: &str,
|
||||
) -> crate::error::Result<Connection> {
|
||||
panic!("remote feature not supported");
|
||||
}
|
||||
|
||||
async fn new_local_connection() -> Result<TestConnection> {
|
||||
let temp_dir = tempdir()?;
|
||||
let uri = temp_dir.path().to_str().unwrap();
|
||||
let connection = connect(uri).execute().await?;
|
||||
Ok(TestConnection {
|
||||
uri: uri.to_string(),
|
||||
connection,
|
||||
_temp_dir: Some(temp_dir),
|
||||
_process: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
5
rust/lancedb/src/test_utils.rs
Normal file
5
rust/lancedb/src/test_utils.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
pub mod connection;
|
||||
pub mod datagen;
|
||||
120
rust/lancedb/src/test_utils/connection.rs
Normal file
120
rust/lancedb/src/test_utils/connection.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Functions for testing connections.
|
||||
|
||||
use regex::Regex;
|
||||
use std::env;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::process::{Child, ChildStdout, Command, Stdio};
|
||||
|
||||
use crate::{connect, Connection};
|
||||
use anyhow::{bail, Result};
|
||||
use tempfile::{tempdir, TempDir};
|
||||
|
||||
pub struct TestConnection {
|
||||
pub uri: String,
|
||||
pub connection: Connection,
|
||||
_temp_dir: Option<TempDir>,
|
||||
_process: Option<TestProcess>,
|
||||
}
|
||||
|
||||
struct TestProcess {
|
||||
child: Child,
|
||||
}
|
||||
|
||||
impl Drop for TestProcess {
|
||||
#[allow(unused_must_use)]
|
||||
fn drop(&mut self) {
|
||||
self.child.kill();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_test_connection() -> Result<TestConnection> {
|
||||
match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
|
||||
Ok(script_path) => new_remote_connection(&script_path).await,
|
||||
Err(_e) => new_local_connection().await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
|
||||
let temp_dir = tempdir()?;
|
||||
let data_path = temp_dir.path().to_str().unwrap().to_string();
|
||||
let child_result = Command::new(script_path)
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.arg(data_path.clone())
|
||||
.spawn();
|
||||
if child_result.is_err() {
|
||||
bail!(format!(
|
||||
"Unable to run {}: {:?}",
|
||||
script_path,
|
||||
child_result.err()
|
||||
));
|
||||
}
|
||||
let mut process = TestProcess {
|
||||
child: child_result.unwrap(),
|
||||
};
|
||||
let stdout = BufReader::new(process.child.stdout.take().unwrap());
|
||||
let port = read_process_port(stdout)?;
|
||||
let uri = "db://test";
|
||||
let host_override = format!("http://localhost:{}", port);
|
||||
let connection = create_new_connection(uri, &host_override).await?;
|
||||
Ok(TestConnection {
|
||||
uri: uri.to_string(),
|
||||
connection,
|
||||
_temp_dir: Some(temp_dir),
|
||||
_process: Some(process),
|
||||
})
|
||||
}
|
||||
|
||||
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
|
||||
let mut line = String::new();
|
||||
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
|
||||
loop {
|
||||
let result = stdout.read_line(&mut line);
|
||||
if let Err(err) = result {
|
||||
bail!(format!(
|
||||
"read_process_port: error while reading from process output: {}",
|
||||
err
|
||||
));
|
||||
} else if result.unwrap() == 0 {
|
||||
bail!("read_process_port: hit EOF before reading port from process output.");
|
||||
}
|
||||
if re.is_match(&line) {
|
||||
let caps = re.captures(&line).unwrap();
|
||||
return Ok(caps[1].to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
|
||||
connect(uri)
|
||||
.region("us-east-1")
|
||||
.api_key("sk_localtest")
|
||||
.host_override(host_override)
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "remote"))]
|
||||
async fn create_new_connection(
|
||||
_uri: &str,
|
||||
_host_override: &str,
|
||||
) -> crate::error::Result<Connection> {
|
||||
panic!("remote feature not supported");
|
||||
}
|
||||
|
||||
async fn new_local_connection() -> Result<TestConnection> {
|
||||
let temp_dir = tempdir()?;
|
||||
let uri = temp_dir.path().to_str().unwrap();
|
||||
let connection = connect(uri).execute().await?;
|
||||
Ok(TestConnection {
|
||||
uri: uri.to_string(),
|
||||
connection,
|
||||
_temp_dir: Some(temp_dir),
|
||||
_process: None,
|
||||
})
|
||||
}
|
||||
55
rust/lancedb/src/test_utils/datagen.rs
Normal file
55
rust/lancedb/src/test_utils/datagen.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
use futures::TryStreamExt;
|
||||
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
connect, Error, Table,
|
||||
};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait LanceDbDatagenExt {
|
||||
async fn into_mem_table(
|
||||
self,
|
||||
table_name: &str,
|
||||
rows_per_batch: RowCount,
|
||||
num_batches: BatchCount,
|
||||
) -> Table;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl LanceDbDatagenExt for BatchGeneratorBuilder {
|
||||
async fn into_mem_table(
|
||||
self,
|
||||
table_name: &str,
|
||||
rows_per_batch: RowCount,
|
||||
num_batches: BatchCount,
|
||||
) -> Table {
|
||||
let (stream, schema) = self.into_reader_stream(rows_per_batch, num_batches);
|
||||
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
|
||||
stream.map_err(Error::from),
|
||||
schema,
|
||||
));
|
||||
let db = connect("memory:///").execute().await.unwrap();
|
||||
db.create_table_streaming(table_name, stream)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn virtual_table(name: &str, values: &RecordBatch) -> Table {
|
||||
let schema = values.schema();
|
||||
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
|
||||
futures::stream::once(std::future::ready(Ok(values.clone()))),
|
||||
schema,
|
||||
));
|
||||
let db = connect("memory:///").execute().await.unwrap();
|
||||
db.create_table_streaming(name, stream)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
Reference in New Issue
Block a user