Compare commits

...

28 Commits

Author SHA1 Message Date
Lance Release
02d31ee412 Bump version: 0.25.3-beta.0 → 0.25.3-beta.1 2025-10-19 23:40:45 +00:00
github-actions[bot]
308623577d chore: update lance dependency to v0.38.3-beta.6 (#2731)
## Summary
- bump Lance dependencies across the workspace to v0.38.3-beta.6
- verified the workspace with cargo clippy --workspace --tests
--all-features -D warnings
- formatted the workspace with cargo fmt --all

## Reference
- https://github.com/lancedb/lance/releases/tag/v0.38.3-beta.6

Co-authored-by: lancedb automation <automation@lancedb.com>
2025-10-19 14:26:20 -07:00
Jack Ye
8ee3ae378f chore: use lance-namespace in lance main repo (#2729)
This fully fixes the duplicated Lance version issue without the need for
a patch section in Cargo.toml.
2025-10-17 22:01:20 -07:00
github-actions[bot]
3372a2aae0 chore: update lance dependency to v0.38.3-beta.5 (#2726)
## Summary
- update Lance dependencies to v0.38.3-beta.4 via
ci/set_lance_version.py
- refresh Cargo.lock for the preview release

## Testing
- cargo clippy --workspace --tests --all-features -- -D warnings
- cargo fmt --all

Triggered by tag:
[v0.38.3-beta.4](https://github.com/lancedb/lance/releases/tag/v0.38.3-beta.4)

Co-authored-by: Jack Ye <yezhaoqin@gmail.com>
2025-10-17 15:17:16 -07:00
Weston Pace
4cfcd95320 feat: add a permutation reader that can read a permutation view (#2712)
This adds a rust permutation builder. In the next PR I will have python
bindings and integration with pytorch.
2025-10-17 05:00:23 -07:00
Xuanwo
a70ff04bc9 ci: polish prompt to make codex happy work (#2724)
Change a few of the prompts to make Codex happy.

Signed-off-by: Xuanwo <github@xuanwo.io>
2025-10-17 17:54:19 +08:00
Xuanwo
a9daa18be9 feat: using codex to auto upgrade lance (#2723)
This PR adds an action that allows Codex to automatically upgrade Lance.

---

**This PR was primarily authored with Codex using GPT-5-Codex and then
hand-reviewed by me. I AM responsible for every change made in this PR.
I aimed to keep it aligned with our goals, though I may have missed
minor issues. Please flag anything that feels off, I'll fix it
quickly.**

Signed-off-by: Xuanwo <github@xuanwo.io>
2025-10-17 17:21:16 +08:00
Ayush Chaurasia
3f2e3986e9 feat: expand support for multivector colpali models and enhancements (#2719) 2025-10-17 14:36:32 +05:30
Rudi Floren
bf55feb9b6 feat: remove dynamodb default dependency (#2720)
`dynamodb` pulls in aws-* crates even if not used.

You can enable the `dynamodb` feature for lancedb to enable it for
lance.

Closes #2718
2025-10-16 10:54:06 -07:00
Weston Pace
8f8e06a2da feat: add output_schema method to queries (#2717)
This is a helper utility I need for some of my data loader work. It
makes it easy to see the output schema even when a `select` has been
applied.
2025-10-14 05:13:28 -07:00
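For context, a minimal sketch of how the new `outputSchema()` method is used from the TypeScript SDK, based on the API docs and tests added later in this diff (the database path, table name, and column names are hypothetical):

```ts
import { connect } from "@lancedb/lancedb";

// Assumes a database at ./data with a table "items" that has an Int64 column "a".
const db = await connect("./data");
const table = await db.openTable("items");

// Inspect the columns a query would return, without executing it.
const schema = await table.query().select({ doubled: "a * 2" }).outputSchema();
console.log(schema.fields.map((f) => `${f.name}: ${f.type}`));
// e.g. ["doubled: Int64"]
```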
Lance Release
03eab0f091 Bump version: 0.22.2 → 0.22.3-beta.0 2025-10-14 02:25:58 +00:00
Lance Release
143184c0ae Bump version: 0.25.2 → 0.25.3-beta.0 2025-10-14 02:25:16 +00:00
Jack Ye
dadb042978 feat: bump lance to 0.38.3-beta.2 and rust to 1.90.0 (#2714) 2025-10-10 14:02:41 -07:00
Weston Pace
5a19cf15a6 feat: a utility for creating "permutation views" (#2552)
I'm working on a lancedb version of pytorch data loading (and hopefully
addressing https://github.com/lancedb/lance/issues/3727).

However, rather than rely on pytorch for everything I'm moving some of
the things that pytorch does into rust. This gives us more control over
data loading (e.g. using shards or a hash-based split) and it allows
permutations to be persistent. In particular I hope to be able to:

* Create a persistent permutation
* This permutation can handle splits, filtering, shuffling, and sharding
* Create a rust data loader that can read a permutation (one or more
splits), or a subset of a permutation (for DDP)
* Create a python data loader that delegates to the rust data loader

Eventually create integrations for other data loading libraries,
including rust & node
2025-10-09 18:07:31 -07:00
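As a rough illustration of the direction described above, here is a minimal sketch using the TypeScript `permutationBuilder` API that appears later in this diff (the database path, table name, and column names are hypothetical):

```ts
import { connect, permutationBuilder } from "@lancedb/lancedb";

// Hypothetical names: a database at ./data with a source table "events"
// that has a numeric "value" column.
const db = await connect("./data");
const source = await db.openTable("events");

// Build a persistent permutation: filter, split into train/validation, then shuffle.
const permutation = await permutationBuilder(source)
  .filter("value <= 80")
  .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
  .shuffle({ seed: 123 })
  .execute();

// Each row in the resulting table carries a split_id column identifying its split.
console.log(await permutation.countRows("split_id = 0"));
```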
Will Jones
3dcec724b7 chore: loosen pin on chrono (#2710)
Fixes #2709
2025-10-09 14:23:56 -07:00
LuQQiu
86a6bb9fcb chore: supports limit push down through MetadataEraserExec (#2679)
For limit to successfully push down to FilteredReadExec
https://github.com/lancedb/lance/pull/4795/
2025-10-09 09:33:38 -07:00
BubbleCal
b59d1007d3 feat(index): add IVF_RQ index type (#2687)
This exposes the IVF_RQ (RabitQ quantization) index type to LanceDB.

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-10-09 15:46:18 +08:00
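A minimal sketch of creating the new index type from the TypeScript SDK, mirroring the test added later in this diff (the database path, table name, column name, and parameter values are illustrative):

```ts
import { Index, connect } from "@lancedb/lancedb";

// Hypothetical names: a database at ./data with a table "embeddings"
// containing a vector column "vec".
const db = await connect("./data");
const tbl = await db.openTable("embeddings");

await tbl.createIndex("vec", {
  config: Index.ivfRq({
    numPartitions: 10, // number of IVF partitions
    numBits: 1, // bits per dimension for RabitQ quantization
  }),
});
```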
Lance Release
56a16b1728 Bump version: 0.22.2-beta.3 → 0.22.2 2025-10-08 18:13:08 +00:00
Lance Release
b7afed9beb Bump version: 0.22.2-beta.2 → 0.22.2-beta.3 2025-10-08 18:12:23 +00:00
Lance Release
5cbbaa2e4a Bump version: 0.25.2-beta.3 → 0.25.2 2025-10-08 18:11:45 +00:00
Lance Release
1b6bd2498e Bump version: 0.25.2-beta.2 → 0.25.2-beta.3 2025-10-08 18:11:45 +00:00
Jack Ye
285da9db1d feat: upgrade lance to 0.38.2 (#2705) 2025-10-08 09:59:28 -07:00
Ayush Chaurasia
ad8306c96b docs: add custom redirect for storage page (#2706)
Expand the custom redirect links list to include the storage page
2025-10-08 21:35:48 +05:30
Wyatt Alt
3594538509 fix: add name to index config and fix create_index typing (#2660)
Co-authored-by: Mark McCaskey <markm@harvey.ai>
2025-10-08 04:41:30 -07:00
Tom LaMarre
917aabd077 fix(node): support specifying arrow field types by name (#2704)
The [`FieldLike` type in
arrow.ts](5ec12c9971/nodejs/lancedb/arrow.ts (L71-L78))
can have a `type: string` property, but before this change, actually
creating a table whose schema specifies field types by name resulted in
an error:

```
Error: Expected a Type but object was null/undefined
```

This change adds support for mapping some type name strings to arrow
`DataType`s, so that passing `FieldLike`s with a `type: string` property
to `sanitizeField` does not throw an error.

The type names that can be passed are upper/lowercase variations of the
keys of the `constructorsByTypeName` object. This does not support
mapping types that need parameters, such as timestamps which need
timezones.

With this, it is possible to create empty tables from `SchemaLike`
objects without instantiating arrow types, e.g.:

```
    import { SchemaLike } from "../lancedb/arrow"
    // ...
    const schemaLike = {
      fields: [
        {
          name: "id",
          type: "int64",
          nullable: true,
        },
        {
          name: "vector",
          type: "float64",
          nullable: true,
        },
      ],
    // ...
    } satisfies SchemaLike;
    const table = await con.createEmptyTable("test", schemaLike);
 ```

This change also makes `FieldLike.nullable` required since the `sanitizeField` function throws if it is undefined.
2025-10-08 04:40:06 -07:00
Jack Ye
5ec12c9971 fix: federated database should not pass namespace to listing database (#2702)
Fixes an error where, when converting a federated database operation to
a listing database operation, the namespace parameter is no longer
correct and should be dropped.

Note that with the testing infra we have today, we don't have a good way
to test these changes. I will do a quick follow-up on
https://github.com/lancedb/lancedb/issues/2701, but it would be great to
get this in first to resolve the related issues.
2025-10-06 14:12:41 -07:00
Ed Rogers
d0ce489b21 fix: use stdlib override when possible (#2699)
## Description of changes

Fixes #2698  

This PR uses
[`typing.override`](https://docs.python.org/3/library/typing.html#typing.override)
in place of the [`overrides`](https://pypi.org/project/overrides/)
dependency when possible. As of Python 3.12, the standard library offers
`typing.override` to perform a static check on overridden methods.

### Motivation

Currently, `overrides` is incompatible with Python 3.14. As a result,
any package that attempts to import `overrides` using Python 3.14+ will
raise an `AttributeError`. An
[issue](https://github.com/mkorpela/overrides/issues/127) has been
raised and a [pull
request](https://github.com/mkorpela/overrides/pull/133) has been
submitted to the GitHub repo for the `overrides` project. But the
maintainer has been unresponsive.

To ensure readiness for Python 3.14, this package (and any other package
directly depending on `overrides`) should consider using
`typing.override` instead.

### Impact

The standard library added `typing.override` as of 3.12. As a result,
this change will affect only users of Python 3.12+. Previous versions
will continue to rely on `overrides`. Notably, the standard library
implementation is slightly different from that of `overrides`. A
thorough discussion of those differences is shown in [PEP
698](https://peps.python.org/pep-0698/), and it is also summarized
nicely by the maintainer of `overrides`
[here](https://github.com/mkorpela/overrides/issues/126#issuecomment-2401327116).

There are two main ways that switching from `overrides` to
`typing.override` will affect developers of this repo.
1. `typing.override` does not implement any runtime checking. Instead,
it provides information to type checkers.
2. The stdlib does not provide a mixin class to enforce override
decorators on child classes. (Their reasoning for this is explained in
[the PEP](https://peps.python.org/pep-0698/).) This PR disables that
behavior entirely by replacing the `EnforceOverrides` mixin.
2025-10-06 11:23:20 -07:00
Lance Release
d7e02c8181 Bump version: 0.22.2-beta.1 → 0.22.2-beta.2 2025-10-06 18:10:40 +00:00
98 changed files with 6374 additions and 684 deletions


@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.22.2-beta.1"
current_version = "0.22.3-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.


@@ -0,0 +1,106 @@
name: Codex Update Lance Dependency
on:
workflow_call:
inputs:
tag:
description: "Tag name from Lance"
required: true
type: string
workflow_dispatch:
inputs:
tag:
description: "Tag name from Lance"
required: true
type: string
permissions:
contents: write
pull-requests: write
actions: read
jobs:
update:
runs-on: ubuntu-latest
steps:
- name: Show inputs
run: |
echo "tag = ${{ inputs.tag }}"
- name: Checkout Repo LanceDB
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: true
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: 20
- name: Install Codex CLI
run: npm install -g @openai/codex
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
components: clippy, rustfmt
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libssl-dev
- name: Install cargo-info
run: cargo install cargo-info
- name: Install Python dependencies
run: python3 -m pip install --upgrade pip packaging
- name: Configure git user
run: |
git config user.name "lancedb automation"
git config user.email "automation@lancedb.com"
- name: Configure Codex authentication
env:
CODEX_TOKEN_B64: ${{ secrets.CODEX_TOKEN }}
run: |
if [ -z "${CODEX_TOKEN_B64}" ]; then
echo "Repository secret CODEX_TOKEN is not defined; skipping Codex execution."
exit 1
fi
mkdir -p ~/.codex
echo "${CODEX_TOKEN_B64}" | base64 --decode > ~/.codex/auth.json
- name: Run Codex to update Lance dependency
env:
TAG: ${{ inputs.tag }}
GITHUB_TOKEN: ${{ github.token }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
VERSION="${TAG#refs/tags/}"
VERSION="${VERSION#v}"
BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
cat <<EOF >/tmp/codex-prompt.txt
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
Follow these steps exactly:
1. Use script `ci/set_lance_version.py` to update Lance dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
2. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
7. Push the branch to origin. If the branch already exists, force-push your changes.
8. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Use the GitHub CLI if helpful.
9. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
Constraints:
- Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.
- Do not merge the PR.
- If any command fails, diagnose and fix the issue instead of aborting.
EOF
codex exec --dangerously-bypass-approvals-and-sandbox "$(cat /tmp/codex-prompt.txt)"

Cargo.lock (generated, 770 changes): file diff suppressed because it is too large.


@@ -15,30 +15,36 @@ categories = ["database-implementations"]
rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.38.0", default-features = false, "features" = ["dynamodb"] }
lance-io = { "version" = "=0.38.0", default-features = false }
lance-index = "=0.38.0"
lance-linalg = "=0.38.0"
lance-table = "=0.38.0"
lance-testing = "=0.38.0"
lance-datafusion = "=0.38.0"
lance-encoding = "=0.38.0"
lance-namespace = "0.0.16"
lance = { "version" = "=0.38.3-beta.6", default-features = false, "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-core = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-datagen = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-file = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.38.3-beta.6", default-features = false, "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-namespace = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-namespace-impls = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.38.3-beta.6", "tag" = "v0.38.3-beta.6", "git" = "https://github.com/lancedb/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "55.1", optional = false }
arrow-array = "55.1"
arrow-data = "55.1"
arrow-ipc = "55.1"
arrow-ord = "55.1"
arrow-schema = "55.1"
arrow-cast = "55.1"
arrow = { version = "56.2", optional = false }
arrow-array = "56.2"
arrow-data = "56.2"
arrow-ipc = "56.2"
arrow-ord = "56.2"
arrow-schema = "56.2"
arrow-select = "56.2"
arrow-cast = "56.2"
async-trait = "0"
datafusion = { version = "49.0", default-features = false }
datafusion-catalog = "49.0"
datafusion-common = { version = "49.0", default-features = false }
datafusion-execution = "49.0"
datafusion-expr = "49.0"
datafusion-physical-plan = "49.0"
datafusion = { version = "50.1", default-features = false }
datafusion-catalog = "50.1"
datafusion-common = { version = "50.1", default-features = false }
datafusion-execution = "50.1"
datafusion-expr = "50.1"
datafusion-physical-plan = "50.1"
env_logger = "0.11"
half = { "version" = "2.6.0", default-features = false, features = [
"num-traits",
@@ -48,6 +54,7 @@ log = "0.4"
moka = { version = "0.12", features = ["future"] }
object_store = "0.12.0"
pin-project = "1.0.7"
rand = "0.9"
snafu = "0.8"
url = "2"
num-traits = "0.2"
@@ -55,20 +62,6 @@ regex = "1.10"
lazy_static = "1"
semver = "1.0.25"
crunchy = "0.2.4"
# Temporary pins to work around downstream issues
# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
chrono = "=0.4.41"
chrono = "0.4"
# Workaround for: https://github.com/Lokathor/bytemuck/issues/306
bytemuck_derive = ">=1.8.1, <1.9.0"
# This is only needed when we reference preview releases of lance
# [patch.crates-io]
# # Force to use the same lance version as the rest of the project to avoid duplicate dependencies
# lance = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-io = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-index = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-linalg = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-table = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-testing = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-datafusion = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-encoding = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
bytemuck_derive = ">=1.8.1, <1.9.0"


@@ -117,7 +117,7 @@ def update_cargo_toml(line_updater):
lance_line = ""
is_parsing_lance_line = False
for line in lines:
if line.startswith("lance") and not line.startswith("lance-namespace"):
if line.startswith("lance"):
# Check if this is a single-line or multi-line entry
# Single-line entries either:
# 1. End with } (complete inline table)
@@ -183,10 +183,8 @@ def set_preview_version(version: str):
def line_updater(line: str) -> str:
package_name = line.split("=", maxsplit=1)[0].strip()
base_version = version.split("-")[0] # Get the base version without beta suffix
# Build config in desired order: version, default-features, features, tag, git
config = {"version": f"={base_version}"}
config = {"version": f"={version}"}
if extract_default_features(line):
config["default-features"] = False


@@ -84,6 +84,7 @@ plugins:
'examples.md': 'https://lancedb.com/docs/tutorials/'
'concepts/vector_search.md': 'https://lancedb.com/docs/search/vector-search/'
'troubleshooting.md': 'https://lancedb.com/docs/troubleshooting/'
'guides/storage.md': 'https://lancedb.com/docs/storage/integrations'
@@ -402,4 +403,4 @@ extra:
- icon: fontawesome/brands/x-twitter
link: https://twitter.com/lancedb
- icon: fontawesome/brands/linkedin
link: https://www.linkedin.com/company/lancedb
link: https://www.linkedin.com/company/lancedb


@@ -194,6 +194,37 @@ currently is also a memory intensive operation.
***
### ivfRq()
```ts
static ivfRq(options?): Index
```
Create an IvfRq index
IVF-RQ (RabitQ Quantization) compresses vectors using RabitQ quantization
and organizes them into IVF partitions.
The compression scheme is called RabitQ quantization. Each dimension is quantized into a small number of bits.
The parameters `num_bits` and `num_partitions` control this process, providing a tradeoff
between index size (and thus search speed) and index accuracy.
The partitioning process is called IVF and the `num_partitions` parameter controls how
many groups to create.
Note that training an IVF RQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.
#### Parameters
* **options?**: `Partial`&lt;[`IvfRqOptions`](../interfaces/IvfRqOptions.md)&gt;
#### Returns
[`Index`](Index.md)
***
### labelList()
```ts


@@ -0,0 +1,220 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / PermutationBuilder
# Class: PermutationBuilder
A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
offering methods to configure data splits, shuffling, and filtering before executing
the permutation to create a new table.
## Methods
### execute()
```ts
execute(): Promise<Table>
```
Execute the permutation and create the destination table.
#### Returns
`Promise`&lt;[`Table`](Table.md)&gt;
A Promise that resolves to the new Table instance
#### Example
```ts
const permutationTable = await builder.execute();
console.log(`Created table: ${permutationTable.name}`);
```
***
### filter()
```ts
filter(filter): PermutationBuilder
```
Configure filtering for the permutation.
#### Parameters
* **filter**: `string`
SQL filter expression
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.filter("age > 18 AND status = 'active'");
```
***
### shuffle()
```ts
shuffle(options): PermutationBuilder
```
Configure shuffling for the permutation.
#### Parameters
* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
Configuration for shuffling
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Basic shuffle
builder.shuffle({ seed: 42 });
// Shuffle with clump size
builder.shuffle({ seed: 42, clumpSize: 10 });
```
***
### splitCalculated()
```ts
splitCalculated(calculation): PermutationBuilder
```
Configure calculated splits for the permutation.
#### Parameters
* **calculation**: `string`
SQL expression for calculating splits
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.splitCalculated("user_id % 3");
```
***
### splitHash()
```ts
splitHash(options): PermutationBuilder
```
Configure hash-based splits for the permutation.
#### Parameters
* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
Configuration for hash-based splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.splitHash({
columns: ["user_id"],
splitWeights: [70, 30],
discardWeight: 0
});
```
***
### splitRandom()
```ts
splitRandom(options): PermutationBuilder
```
Configure random splits for the permutation.
#### Parameters
* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
Configuration for random splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Split by ratios
builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
// Split by counts
builder.splitRandom({ counts: [1000, 500], seed: 42 });
// Split with fixed size
builder.splitRandom({ fixed: 100, seed: 42 });
```
***
### splitSequential()
```ts
splitSequential(options): PermutationBuilder
```
Configure sequential splits for the permutation.
#### Parameters
* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
Configuration for sequential splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Split by ratios
builder.splitSequential({ ratios: [0.8, 0.2] });
// Split by counts
builder.splitSequential({ counts: [800, 200] });
// Split with fixed size
builder.splitSequential({ fixed: 1000 });
```


@@ -343,6 +343,29 @@ This is useful for pagination.
***
### outputSchema()
```ts
outputSchema(): Promise<Schema<any>>
```
Returns the schema of the output that will be returned by this query.
This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.
#### Returns
`Promise`&lt;`Schema`&lt;`any`&gt;&gt;
An Arrow Schema describing the output columns.
#### Inherited from
`StandardQueryBase.outputSchema`
***
### select()
```ts


@@ -140,6 +140,25 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
***
### outputSchema()
```ts
outputSchema(): Promise<Schema<any>>
```
Returns the schema of the output that will be returned by this query.
This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.
#### Returns
`Promise`&lt;`Schema`&lt;`any`&gt;&gt;
An Arrow Schema describing the output columns.
***
### select()
```ts


@@ -143,6 +143,29 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
***
### outputSchema()
```ts
outputSchema(): Promise<Schema<any>>
```
Returns the schema of the output that will be returned by this query.
This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.
#### Returns
`Promise`&lt;`Schema`&lt;`any`&gt;&gt;
An Arrow Schema describing the output columns.
#### Inherited from
[`QueryBase`](QueryBase.md).[`outputSchema`](QueryBase.md#outputschema)
***
### select()
```ts


@@ -498,6 +498,29 @@ This is useful for pagination.
***
### outputSchema()
```ts
outputSchema(): Promise<Schema<any>>
```
Returns the schema of the output that will be returned by this query.
This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.
#### Returns
`Promise`&lt;`Schema`&lt;`any`&gt;&gt;
An Arrow Schema describing the output columns.
#### Inherited from
`StandardQueryBase.outputSchema`
***
### postfilter()
```ts


@@ -0,0 +1,34 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / permutationBuilder
# Function: permutationBuilder()
```ts
function permutationBuilder(table): PermutationBuilder
```
Create a permutation builder for the given table.
## Parameters
* **table**: [`Table`](../classes/Table.md)
The source table to create a permutation from
## Returns
[`PermutationBuilder`](../classes/PermutationBuilder.md)
A PermutationBuilder instance
## Example
```ts
const builder = permutationBuilder(sourceTable, "training_data")
.splitRandom({ ratios: [0.8, 0.2], seed: 42 })
.shuffle({ seed: 123 });
const trainingTable = await builder.execute();
```


@@ -28,6 +28,7 @@
- [MultiMatchQuery](classes/MultiMatchQuery.md)
- [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md)
- [OAuthHeaderProvider](classes/OAuthHeaderProvider.md)
- [PermutationBuilder](classes/PermutationBuilder.md)
- [PhraseQuery](classes/PhraseQuery.md)
- [Query](classes/Query.md)
- [QueryBase](classes/QueryBase.md)
@@ -68,6 +69,7 @@
- [IndexStatistics](interfaces/IndexStatistics.md)
- [IvfFlatOptions](interfaces/IvfFlatOptions.md)
- [IvfPqOptions](interfaces/IvfPqOptions.md)
- [IvfRqOptions](interfaces/IvfRqOptions.md)
- [MergeResult](interfaces/MergeResult.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
@@ -75,6 +77,10 @@
- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
- [RemovalStats](interfaces/RemovalStats.md)
- [RetryConfig](interfaces/RetryConfig.md)
- [ShuffleOptions](interfaces/ShuffleOptions.md)
- [SplitHashOptions](interfaces/SplitHashOptions.md)
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
- [TableNamesOptions](interfaces/TableNamesOptions.md)
- [TableStatistics](interfaces/TableStatistics.md)
- [TimeoutConfig](interfaces/TimeoutConfig.md)
@@ -102,3 +108,4 @@
- [connect](functions/connect.md)
- [makeArrowTable](functions/makeArrowTable.md)
- [packBits](functions/packBits.md)
- [permutationBuilder](functions/permutationBuilder.md)


@@ -0,0 +1,101 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / IvfRqOptions
# Interface: IvfRqOptions
## Properties
### distanceType?
```ts
optional distanceType: "l2" | "cosine" | "dot";
```
Distance type to use to build the index.
Default value is "l2".
This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and during quantization.
The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.
The following distance types are available:
"l2" - Euclidean distance.
"cosine" - Cosine distance.
"dot" - Dot product.
***
### maxIterations?
```ts
optional maxIterations: number;
```
Max iterations to train IVF kmeans.
When training an IVF index we use kmeans to calculate the partitions. This parameter
controls how many iterations of kmeans to run.
The default value is 50.
***
### numBits?
```ts
optional numBits: number;
```
Number of bits per dimension for residual quantization.
This value controls how much each residual component is compressed. The more
bits, the more accurate the index will be but the slower search. Typical values
are small integers; the default is 1 bit per dimension.
***
### numPartitions?
```ts
optional numPartitions: number;
```
The number of IVF partitions to create.
This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.
If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.
***
### sampleRate?
```ts
optional sampleRate: number;
```
The number of vectors, per partition, to sample when training IVF kmeans.
When an IVF index is trained, we need to calculate partitions. These are groups
of vectors that are similar to each other. To do this we use an algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
random sample of the data. This parameter controls the size of the sample. The total
number of vectors used to train the index is `sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in most cases the
default should be sufficient.
The default value is 256.


@@ -0,0 +1,23 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / ShuffleOptions
# Interface: ShuffleOptions
## Properties
### clumpSize?
```ts
optional clumpSize: number;
```
***
### seed?
```ts
optional seed: number;
```


@@ -0,0 +1,31 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitHashOptions
# Interface: SplitHashOptions
## Properties
### columns
```ts
columns: string[];
```
***
### discardWeight?
```ts
optional discardWeight: number;
```
***
### splitWeights
```ts
splitWeights: number[];
```


@@ -0,0 +1,39 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitRandomOptions
# Interface: SplitRandomOptions
## Properties
### counts?
```ts
optional counts: number[];
```
***
### fixed?
```ts
optional fixed: number;
```
***
### ratios?
```ts
optional ratios: number[];
```
***
### seed?
```ts
optional seed: number;
```


@@ -0,0 +1,31 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitSequentialOptions
# Interface: SplitSequentialOptions
## Properties
### counts?
```ts
optional counts: number[];
```
***
### fixed?
```ts
optional fixed: number;
```
***
### ratios?
```ts
optional ratios: number[];
```


@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-beta.1</version>
<version>0.22.3-beta.0</version>
<relativePath>../pom.xml</relativePath>
</parent>


@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-beta.1</version>
<version>0.22.3-beta.0</version>
<relativePath>../pom.xml</relativePath>
</parent>


@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-beta.1</version>
<version>0.22.3-beta.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>


@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.22.2-beta.1"
version = "0.22.3-beta.0"
license.workspace = true
description.workspace = true
repository.workspace = true


@@ -0,0 +1,227 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as tmp from "tmp";
import { Table, connect, permutationBuilder } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
describe("PermutationBuilder", () => {
let tmpDir: tmp.DirResult;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const db = await connect(tmpDir.name);
// Create test data
const data = makeArrowTable(
[
{ id: 1, value: 10 },
{ id: 2, value: 20 },
{ id: 3, value: 30 },
{ id: 4, value: 40 },
{ id: 5, value: 50 },
{ id: 6, value: 60 },
{ id: 7, value: 70 },
{ id: 8, value: 80 },
{ id: 9, value: 90 },
{ id: 10, value: 100 },
],
{ vectorColumns: {} },
);
table = await db.createTable("test_table", data);
});
afterEach(() => {
tmpDir.removeCallback();
});
test("should create permutation builder", () => {
const builder = permutationBuilder(table);
expect(builder).toBeDefined();
});
test("should execute basic permutation", async () => {
const builder = permutationBuilder(table);
const permutationTable = await builder.execute();
expect(permutationTable).toBeDefined();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with random splits", async () => {
const builder = permutationBuilder(table).splitRandom({
ratios: [1.0],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with percentage splits", async () => {
const builder = permutationBuilder(table).splitRandom({
ratios: [0.3, 0.7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with count splits", async () => {
const builder = permutationBuilder(table).splitRandom({
counts: [3, 7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(3);
expect(split1Count).toBe(7);
});
test("should create permutation with hash splits", async () => {
const builder = permutationBuilder(table).splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check that splits exist
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with sequential splits", async () => {
const builder = permutationBuilder(table).splitSequential({
ratios: [0.5, 0.5],
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution - sequential should give exactly 5 and 5
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(5);
expect(split1Count).toBe(5);
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(table).splitCalculated("id % 2");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with shuffle", async () => {
const builder = permutationBuilder(table).shuffle({
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with shuffle and clump size", async () => {
const builder = permutationBuilder(table).shuffle({
seed: 42,
clumpSize: 2,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with filter", async () => {
const builder = permutationBuilder(table).filter("value > 50");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100
});
test("should chain multiple operations", async () => {
const builder = permutationBuilder(table)
.filter("value <= 80")
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
.shuffle({ seed: 123 });
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(8);
});
test("should throw error for invalid split arguments", () => {
const builder = permutationBuilder(table);
// Test no arguments provided
expect(() => builder.splitRandom({})).toThrow(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
);
// Test multiple arguments provided
expect(() =>
builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }),
).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided");
});
test("should throw error when builder is consumed", async () => {
const builder = permutationBuilder(table);
// Execute once
await builder.execute();
// Should throw error on second execution
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
});
});


@@ -0,0 +1,111 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as tmp from "tmp";
import { type Table, connect } from "../lancedb";
import {
Field,
FixedSizeList,
Float32,
Int64,
Schema,
Utf8,
makeArrowTable,
} from "../lancedb/arrow";
import { Index } from "../lancedb/indices";
describe("Query outputSchema", () => {
let tmpDir: tmp.DirResult;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const db = await connect(tmpDir.name);
// Create table with explicit schema to ensure proper types
const schema = new Schema([
new Field("a", new Int64(), true),
new Field("text", new Utf8(), true),
new Field(
"vec",
new FixedSizeList(2, new Field("item", new Float32())),
true,
),
]);
const data = makeArrowTable(
[
{ a: 1n, text: "foo", vec: [1, 2] },
{ a: 2n, text: "bar", vec: [3, 4] },
{ a: 3n, text: "baz", vec: [5, 6] },
],
{ schema },
);
table = await db.createTable("test", data);
});
afterEach(() => {
tmpDir.removeCallback();
});
it("should return schema for plain query", async () => {
const schema = await table.query().outputSchema();
expect(schema.fields.length).toBe(3);
expect(schema.fields.map((f) => f.name)).toEqual(["a", "text", "vec"]);
expect(schema.fields[0].type.toString()).toBe("Int64");
expect(schema.fields[1].type.toString()).toBe("Utf8");
});
it("should return schema with dynamic projection", async () => {
const schema = await table.query().select({ bl: "a * 2" }).outputSchema();
expect(schema.fields.length).toBe(1);
expect(schema.fields[0].name).toBe("bl");
expect(schema.fields[0].type.toString()).toBe("Int64");
});
it("should return schema for vector search with _distance column", async () => {
const schema = await table
.vectorSearch([1, 2])
.select(["a"])
.outputSchema();
expect(schema.fields.length).toBe(2);
expect(schema.fields.map((f) => f.name)).toEqual(["a", "_distance"]);
expect(schema.fields[0].type.toString()).toBe("Int64");
expect(schema.fields[1].type.toString()).toBe("Float32");
});
it("should return schema for FTS search", async () => {
await table.createIndex("text", { config: Index.fts() });
const schema = await table
.search("foo", "fts")
.select(["a"])
.outputSchema();
// FTS search includes _score column in addition to selected columns
expect(schema.fields.length).toBe(2);
expect(schema.fields.map((f) => f.name)).toContain("a");
expect(schema.fields.map((f) => f.name)).toContain("_score");
const aField = schema.fields.find((f) => f.name === "a");
expect(aField?.type.toString()).toBe("Int64");
});
it("should return schema for take query", async () => {
const schema = await table.takeOffsets([0]).select(["text"]).outputSchema();
expect(schema.fields.length).toBe(1);
expect(schema.fields[0].name).toBe("text");
expect(schema.fields[0].type.toString()).toBe("Utf8");
});
it("should return full schema when no select is specified", async () => {
const schema = await table.query().outputSchema();
// Should return all columns
expect(schema.fields.length).toBe(3);
});
});


@@ -0,0 +1,184 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as arrow from "../lancedb/arrow";
import { sanitizeField, sanitizeType } from "../lancedb/sanitize";
describe("sanitize", function () {
describe("sanitizeType function", function () {
it("should handle type objects", function () {
const type = new arrow.Int32();
const result = sanitizeType(type);
expect(result.typeId).toBe(arrow.Type.Int);
expect((result as arrow.Int).bitWidth).toBe(32);
expect((result as arrow.Int).isSigned).toBe(true);
const floatType = {
typeId: 3, // Type.Float = 3
precision: 2,
toString: () => "Float",
isFloat: true,
isFixedWidth: true,
};
const floatResult = sanitizeType(floatType);
expect(floatResult).toBeInstanceOf(arrow.DataType);
expect(floatResult.typeId).toBe(arrow.Type.Float);
const floatResult2 = sanitizeType({ ...floatType, typeId: () => 3 });
expect(floatResult2).toBeInstanceOf(arrow.DataType);
expect(floatResult2.typeId).toBe(arrow.Type.Float);
});
const allTypeNameTestCases = [
["null", new arrow.Null()],
["binary", new arrow.Binary()],
["utf8", new arrow.Utf8()],
["bool", new arrow.Bool()],
["int8", new arrow.Int8()],
["int16", new arrow.Int16()],
["int32", new arrow.Int32()],
["int64", new arrow.Int64()],
["uint8", new arrow.Uint8()],
["uint16", new arrow.Uint16()],
["uint32", new arrow.Uint32()],
["uint64", new arrow.Uint64()],
["float16", new arrow.Float16()],
["float32", new arrow.Float32()],
["float64", new arrow.Float64()],
["datemillisecond", new arrow.DateMillisecond()],
["dateday", new arrow.DateDay()],
["timenanosecond", new arrow.TimeNanosecond()],
["timemicrosecond", new arrow.TimeMicrosecond()],
["timemillisecond", new arrow.TimeMillisecond()],
["timesecond", new arrow.TimeSecond()],
["intervaldaytime", new arrow.IntervalDayTime()],
["intervalyearmonth", new arrow.IntervalYearMonth()],
["durationnanosecond", new arrow.DurationNanosecond()],
["durationmicrosecond", new arrow.DurationMicrosecond()],
["durationmillisecond", new arrow.DurationMillisecond()],
["durationsecond", new arrow.DurationSecond()],
] as const;
it.each(allTypeNameTestCases)(
'should map type name "%s" to %s',
function (name, expected) {
const result = sanitizeType(name);
expect(result).toBeInstanceOf(expected.constructor);
},
);
const caseVariationTestCases = [
["NULL", new arrow.Null()],
["Utf8", new arrow.Utf8()],
["FLOAT32", new arrow.Float32()],
["DaTedAy", new arrow.DateDay()],
] as const;
it.each(caseVariationTestCases)(
'should be case insensitive for type name "%s" mapped to %s',
function (name, expected) {
const result = sanitizeType(name);
expect(result).toBeInstanceOf(expected.constructor);
},
);
it("should throw error for unrecognized type name", function () {
expect(() => sanitizeType("invalid_type")).toThrow(
"Unrecognized type name in schema: invalid_type",
);
});
});
describe("sanitizeField function", function () {
it("should handle field with string type name", function () {
const field = sanitizeField({
name: "string_field",
type: "utf8",
nullable: true,
metadata: new Map([["key", "value"]]),
});
expect(field).toBeInstanceOf(arrow.Field);
expect(field.name).toBe("string_field");
expect(field.type).toBeInstanceOf(arrow.Utf8);
expect(field.nullable).toBe(true);
expect(field.metadata?.get("key")).toBe("value");
});
it("should handle field with type object", function () {
const floatType = {
typeId: 3, // Float
precision: 32,
};
const field = sanitizeField({
name: "float_field",
type: floatType,
nullable: false,
});
expect(field).toBeInstanceOf(arrow.Field);
expect(field.name).toBe("float_field");
expect(field.type).toBeInstanceOf(arrow.DataType);
expect(field.type.typeId).toBe(arrow.Type.Float);
expect((field.type as arrow.Float64).precision).toBe(32);
expect(field.nullable).toBe(false);
});
it("should handle field with direct Type instance", function () {
const field = sanitizeField({
name: "bool_field",
type: new arrow.Bool(),
nullable: true,
});
expect(field).toBeInstanceOf(arrow.Field);
expect(field.name).toBe("bool_field");
expect(field.type).toBeInstanceOf(arrow.Bool);
expect(field.nullable).toBe(true);
});
it("should throw error for invalid field object", function () {
expect(() =>
sanitizeField({
type: "int32",
nullable: true,
}),
).toThrow(
"The field passed in is missing a `type`/`name`/`nullable` property",
);
// Invalid type
expect(() =>
sanitizeField({
name: "invalid",
type: { invalid: true },
nullable: true,
}),
).toThrow("Expected a Type to have a typeId property");
// Invalid nullable
expect(() =>
sanitizeField({
name: "invalid_nullable",
type: "int32",
nullable: "not a boolean",
}),
).toThrow("The field passed in had a non-boolean `nullable` property");
});
it("should report error for invalid type name", function () {
expect(() =>
sanitizeField({
name: "invalid_field",
type: "invalid_type",
nullable: true,
}),
).toThrow(
"Unable to sanitize type for field: invalid_field due to error: Error: Unrecognized type name in schema: invalid_type",
);
});
});
});


@@ -10,7 +10,13 @@ import * as arrow16 from "apache-arrow-16";
import * as arrow17 from "apache-arrow-17";
import * as arrow18 from "apache-arrow-18";
import { MatchQuery, PhraseQuery, Table, connect } from "../lancedb";
import {
Connection,
MatchQuery,
PhraseQuery,
Table,
connect,
} from "../lancedb";
import {
Table as ArrowTable,
Field,
@@ -21,6 +27,8 @@ import {
Int64,
List,
Schema,
SchemaLike,
Type,
Uint8,
Utf8,
makeArrowTable,
@@ -853,6 +861,15 @@ describe("When creating an index", () => {
});
});
it("should be able to create IVF_RQ", async () => {
await tbl.createIndex("vec", {
config: Index.ivfRq({
numPartitions: 10,
numBits: 1,
}),
});
});
it("should allow me to replace (or not) an existing index", async () => {
await tbl.createIndex("id");
// Default is replace=true
@@ -2019,3 +2036,52 @@ describe("column name options", () => {
expect(results2.length).toBe(10);
});
});
describe("when creating an empty table", () => {
let con: Connection;
beforeEach(async () => {
const tmpDir = tmp.dirSync({ unsafeCleanup: true });
con = await connect(tmpDir.name);
});
afterEach(() => {
con.close();
});
it("can create an empty table from an arrow Schema", async () => {
const schema = new Schema([
new Field("id", new Int64()),
new Field("vector", new Float64()),
]);
const table = await con.createEmptyTable("test", schema);
const actualSchema = await table.schema();
expect(actualSchema.fields[0].type.typeId).toBe(Type.Int);
expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64);
expect(actualSchema.fields[1].type.typeId).toBe(Type.Float);
expect((actualSchema.fields[1].type as Float64).precision).toBe(2);
});
it("can create an empty table from schema that specifies field types by name", async () => {
const schemaLike = {
fields: [
{
name: "id",
type: "int64",
nullable: true,
},
{
name: "vector",
type: "float64",
nullable: true,
},
],
metadata: new Map(),
names: ["id", "vector"],
} satisfies SchemaLike;
const table = await con.createEmptyTable("test", schemaLike);
const actualSchema = await table.schema();
expect(actualSchema.fields[0].type.typeId).toBe(Type.Int);
expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64);
expect(actualSchema.fields[1].type.typeId).toBe(Type.Float);
expect((actualSchema.fields[1].type as Float64).precision).toBe(2);
});
});


@@ -73,7 +73,7 @@ export type FieldLike =
| {
type: string;
name: string;
nullable?: boolean;
nullable: boolean;
metadata?: Map<string, string>;
};


@@ -43,6 +43,10 @@ export {
DeleteResult,
DropColumnsResult,
UpdateResult,
SplitRandomOptions,
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
} from "./native.js";
export {
@@ -85,6 +89,7 @@ export {
Index,
IndexOptions,
IvfPqOptions,
IvfRqOptions,
IvfFlatOptions,
HnswPqOptions,
HnswSqOptions,
@@ -110,6 +115,7 @@ export {
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";
export { permutationBuilder, PermutationBuilder } from "./permutation";
export * as rerankers from "./rerankers";
export {
SchemaLike,


@@ -112,6 +112,77 @@ export interface IvfPqOptions {
sampleRate?: number;
}
export interface IvfRqOptions {
/**
* The number of IVF partitions to create.
*
* This value should generally scale with the number of rows in the dataset.
* By default the number of partitions is the square root of the number of
* rows.
*
* If this value is too large then the first part of the search (picking the
* right partition) will be slow. If this value is too small then the second
* part of the search (searching within a partition) will be slow.
*/
numPartitions?: number;
/**
* Number of bits per dimension for residual quantization.
*
* This value controls how much each residual component is compressed. The more
* bits, the more accurate the index will be but the slower search. Typical values
* are small integers; the default is 1 bit per dimension.
*/
numBits?: number;
/**
* Distance type to use to build the index.
*
* Default value is "l2".
*
* This is used when training the index to calculate the IVF partitions
* (vectors are grouped in partitions with similar vectors according to this
* distance type) and during quantization.
*
* The distance type used to train an index MUST match the distance type used
* to search the index. Failure to do so will yield inaccurate results.
*
* The following distance types are available:
*
* "l2" - Euclidean distance.
* "cosine" - Cosine distance.
* "dot" - Dot product.
*/
distanceType?: "l2" | "cosine" | "dot";
/**
* Max iterations to train IVF kmeans.
*
* When training an IVF index we use kmeans to calculate the partitions. This parameter
* controls how many iterations of kmeans to run.
*
* The default value is 50.
*/
maxIterations?: number;
/**
* The number of vectors, per partition, to sample when training IVF kmeans.
*
* When an IVF index is trained, we need to calculate partitions. These are groups
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
*
* Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
* random sample of the data. This parameter controls the size of the sample. The total
* number of vectors used to train the index is `sample_rate * num_partitions`.
*
* Increasing this value might improve the quality of the index but in most cases the
* default should be sufficient.
*
* The default value is 256.
*/
sampleRate?: number;
}
/**
* Options to create an `HNSW_PQ` index
*/
@@ -523,6 +594,35 @@ export class Index {
options?.distanceType,
options?.numPartitions,
options?.numSubVectors,
options?.numBits,
options?.maxIterations,
options?.sampleRate,
),
);
}
/**
* Create an IvfRq index
*
* IVF-RQ (RabitQ Quantization) compresses vectors using RabitQ quantization
* and organizes them into IVF partitions.
*
* The compression scheme is called RabitQ quantization. Each dimension is quantized into a small number of bits.
* The parameters `num_bits` and `num_partitions` control this process, providing a tradeoff
* between index size (and thus search speed) and index accuracy.
*
* The partitioning process is called IVF and the `num_partitions` parameter controls how
* many groups to create.
*
* Note that training an IVF RQ index on a large dataset is a slow operation and
* currently is also a memory intensive operation.
*/
static ivfRq(options?: Partial<IvfRqOptions>) {
return new Index(
LanceDbIndex.ivfRq(
options?.distanceType,
options?.numPartitions,
options?.numBits,
options?.maxIterations,
options?.sampleRate,
),


@@ -0,0 +1,183 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import {
PermutationBuilder as NativePermutationBuilder,
Table as NativeTable,
ShuffleOptions,
SplitHashOptions,
SplitRandomOptions,
SplitSequentialOptions,
permutationBuilder as nativePermutationBuilder,
} from "./native.js";
import { LocalTable, Table } from "./table";
/**
* A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
*
* This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
* offering methods to configure data splits, shuffling, and filtering before executing
* the permutation to create a new table.
*/
export class PermutationBuilder {
private inner: NativePermutationBuilder;
/**
* @hidden
*/
constructor(inner: NativePermutationBuilder) {
this.inner = inner;
}
/**
* Configure random splits for the permutation.
*
* @param options - Configuration for random splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
*
* // Split by counts
* builder.splitRandom({ counts: [1000, 500], seed: 42 });
*
* // Split with fixed size
* builder.splitRandom({ fixed: 100, seed: 42 });
* ```
*/
splitRandom(options: SplitRandomOptions): PermutationBuilder {
const newInner = this.inner.splitRandom(options);
return new PermutationBuilder(newInner);
}
/**
* Configure hash-based splits for the permutation.
*
* @param options - Configuration for hash-based splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitHash({
* columns: ["user_id"],
* splitWeights: [70, 30],
* discardWeight: 0
* });
* ```
*/
splitHash(options: SplitHashOptions): PermutationBuilder {
const newInner = this.inner.splitHash(options);
return new PermutationBuilder(newInner);
}
/**
* Configure sequential splits for the permutation.
*
* @param options - Configuration for sequential splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitSequential({ ratios: [0.8, 0.2] });
*
* // Split by counts
* builder.splitSequential({ counts: [800, 200] });
*
* // Split with fixed size
* builder.splitSequential({ fixed: 1000 });
* ```
*/
splitSequential(options: SplitSequentialOptions): PermutationBuilder {
const newInner = this.inner.splitSequential(options);
return new PermutationBuilder(newInner);
}
/**
* Configure calculated splits for the permutation.
*
* @param calculation - SQL expression for calculating splits
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitCalculated("user_id % 3");
* ```
*/
splitCalculated(calculation: string): PermutationBuilder {
const newInner = this.inner.splitCalculated(calculation);
return new PermutationBuilder(newInner);
}
/**
* Configure shuffling for the permutation.
*
* @param options - Configuration for shuffling
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Basic shuffle
* builder.shuffle({ seed: 42 });
*
* // Shuffle with clump size
* builder.shuffle({ seed: 42, clumpSize: 10 });
* ```
*/
shuffle(options: ShuffleOptions): PermutationBuilder {
const newInner = this.inner.shuffle(options);
return new PermutationBuilder(newInner);
}
/**
* Configure filtering for the permutation.
*
* @param filter - SQL filter expression
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.filter("age > 18 AND status = 'active'");
* ```
*/
filter(filter: string): PermutationBuilder {
const newInner = this.inner.filter(filter);
return new PermutationBuilder(newInner);
}
/**
* Execute the permutation and create the destination table.
*
* @returns A Promise that resolves to the new Table instance
* @example
* ```ts
* const permutationTable = await builder.execute();
* console.log(`Created table: ${permutationTable.name}`);
* ```
*/
async execute(): Promise<Table> {
const nativeTable: NativeTable = await this.inner.execute();
return new LocalTable(nativeTable);
}
}
/**
* Create a permutation builder for the given table.
*
* @param table - The source table to create a permutation from
* @returns A PermutationBuilder instance
* @example
* ```ts
 * const builder = permutationBuilder(sourceTable)
* .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
* .shuffle({ seed: 123 });
*
* const trainingTable = await builder.execute();
* ```
*/
export function permutationBuilder(table: Table): PermutationBuilder {
// Extract the inner native table from the TypeScript wrapper
const localTable = table as LocalTable;
// Access inner through type assertion since it's private
const nativeBuilder = nativePermutationBuilder(
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
(localTable as any).inner,
);
return new PermutationBuilder(nativeBuilder);
}

View File

@@ -326,6 +326,25 @@ export class QueryBase<
return this.inner.analyzePlan();
}
}
/**
* Returns the schema of the output that will be returned by this query.
*
* This can be used to inspect the types and names of the columns that will be
* returned by the query before executing it.
*
* @returns An Arrow Schema describing the output columns.
*/
async outputSchema(): Promise<import("./arrow").Schema> {
let schemaBuffer: Buffer;
if (this.inner instanceof Promise) {
schemaBuffer = await this.inner.then((inner) => inner.outputSchema());
} else {
schemaBuffer = await this.inner.outputSchema();
}
const schema = tableFromIPC(schemaBuffer).schema;
return schema;
}
}
export class StandardQueryBase<

View File

@@ -326,6 +326,9 @@ export function sanitizeDictionary(typeLike: object) {
// biome-ignore lint/suspicious/noExplicitAny: skip
export function sanitizeType(typeLike: unknown): DataType<any> {
if (typeof typeLike === "string") {
return dataTypeFromName(typeLike);
}
if (typeof typeLike !== "object" || typeLike === null) {
throw Error("Expected a Type but object was null/undefined");
}
@@ -447,7 +450,7 @@ export function sanitizeType(typeLike: unknown): DataType<any> {
case Type.DurationSecond:
return new DurationSecond();
default:
throw new Error("Unrecoginized type id in schema: " + typeId);
throw new Error("Unrecognized type id in schema: " + typeId);
}
}
@@ -467,7 +470,15 @@ export function sanitizeField(fieldLike: unknown): Field {
"The field passed in is missing a `type`/`name`/`nullable` property",
);
}
const type = sanitizeType(fieldLike.type);
let type: DataType;
try {
type = sanitizeType(fieldLike.type);
} catch (error: unknown) {
throw Error(
`Unable to sanitize type for field: ${fieldLike.name} due to error: ${error}`,
{ cause: error },
);
}
const name = fieldLike.name;
if (!(typeof name === "string")) {
throw Error("The field passed in had a non-string `name` property");
@@ -581,3 +592,46 @@ function sanitizeData(
},
);
}
const constructorsByTypeName = {
null: () => new Null(),
binary: () => new Binary(),
utf8: () => new Utf8(),
bool: () => new Bool(),
int8: () => new Int8(),
int16: () => new Int16(),
int32: () => new Int32(),
int64: () => new Int64(),
uint8: () => new Uint8(),
uint16: () => new Uint16(),
uint32: () => new Uint32(),
uint64: () => new Uint64(),
float16: () => new Float16(),
float32: () => new Float32(),
float64: () => new Float64(),
datemillisecond: () => new DateMillisecond(),
dateday: () => new DateDay(),
timenanosecond: () => new TimeNanosecond(),
timemicrosecond: () => new TimeMicrosecond(),
timemillisecond: () => new TimeMillisecond(),
timesecond: () => new TimeSecond(),
intervaldaytime: () => new IntervalDayTime(),
intervalyearmonth: () => new IntervalYearMonth(),
durationnanosecond: () => new DurationNanosecond(),
durationmicrosecond: () => new DurationMicrosecond(),
durationmillisecond: () => new DurationMillisecond(),
durationsecond: () => new DurationSecond(),
} as const;
type MappableTypeName = keyof typeof constructorsByTypeName;
export function dataTypeFromName(typeName: string): DataType {
const normalizedTypeName = typeName.toLowerCase() as MappableTypeName;
const _constructor = constructorsByTypeName[normalizedTypeName];
if (!_constructor) {
throw new Error("Unrecognized type name in schema: " + typeName);
}
return _constructor();
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.22.2-beta.1",
"version": "0.22.3-beta.0",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -6,6 +6,7 @@ use std::sync::Mutex;
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
use lancedb::index::vector::{
IvfFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder,
IvfRqIndexBuilder,
};
use lancedb::index::Index as LanceDbIndex;
use napi_derive::napi;
@@ -65,6 +66,36 @@ impl Index {
})
}
#[napi(factory)]
pub fn ivf_rq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> napi::Result<Self> {
let mut ivf_rq_builder = IvfRqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
ivf_rq_builder = ivf_rq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
}
if let Some(num_bits) = num_bits {
ivf_rq_builder = ivf_rq_builder.num_bits(num_bits);
}
if let Some(max_iterations) = max_iterations {
ivf_rq_builder = ivf_rq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
ivf_rq_builder = ivf_rq_builder.sample_rate(sample_rate);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfRq(ivf_rq_builder))),
})
}
#[napi(factory)]
pub fn ivf_flat(
distance_type: Option<String>,

View File

@@ -12,6 +12,7 @@ mod header;
mod index;
mod iterator;
pub mod merge;
pub mod permutation;
mod query;
pub mod remote;
mod rerankers;

214
nodejs/src/permutation.rs Normal file
View File

@@ -0,0 +1,214 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
permutation::split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;
#[napi(object)]
pub struct SplitRandomOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
}
#[napi(object)]
pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
}
#[napi(object)]
pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
}
#[napi(object)]
pub struct ShuffleOptions {
pub seed: Option<i64>,
pub clump_size: Option<i64>,
}
pub struct PermutationBuilderState {
pub builder: Option<LancePermutationBuilder>,
}
#[napi]
pub struct PermutationBuilder {
state: Arc<Mutex<PermutationBuilderState>>,
}
impl PermutationBuilder {
pub fn new(builder: LancePermutationBuilder) -> Self {
Self {
state: Arc::new(Mutex::new(PermutationBuilderState {
builder: Some(builder),
})),
}
}
}
impl PermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> napi::Result<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[napi]
impl PermutationBuilder {
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
/// Configure hash-based splits
#[napi]
pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<Self> {
let split_weights = options
.split_weights
.into_iter()
.map(|w| w as u64)
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
})
}
/// Configure sequential splits
#[napi]
pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
})
}
/// Configure shuffling
#[napi]
pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<Self> {
let seed = options.seed.map(|s| s as u64);
let clump_size = options.clump_size.map(|c| c as u64);
self.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
/// Configure filtering
#[napi]
pub fn filter(&self, filter: String) -> napi::Result<Self> {
self.modify(|builder| builder.with_filter(filter))
}
/// Execute the permutation builder and create the table
#[napi]
pub async fn execute(&self) -> napi::Result<Table> {
let builder = {
let mut state = self.state.lock().unwrap();
state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?
};
let table = builder.build().await.default_error()?;
Ok(Table::new(table))
}
}
/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(table: &crate::table::Table) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder;
let inner_table = table.inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PermutationBuilder::new(inner_builder))
}

View File

@@ -22,7 +22,7 @@ use crate::error::NapiErrorExt;
use crate::iterator::RecordBatchIterator;
use crate::rerankers::Reranker;
use crate::rerankers::RerankerCallbacks;
use crate::util::parse_distance_type;
use crate::util::{parse_distance_type, schema_to_buffer};
#[napi]
pub struct Query {
@@ -88,6 +88,12 @@ impl Query {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)]
pub async fn output_schema(&self) -> napi::Result<Buffer> {
let schema = self.inner.output_schema().await.default_error()?;
schema_to_buffer(&schema)
}
#[napi(catch_unwind)]
pub async fn execute(
&self,
@@ -273,6 +279,12 @@ impl VectorQuery {
.rerank(Arc::new(Reranker::new(callbacks)));
}
#[napi(catch_unwind)]
pub async fn output_schema(&self) -> napi::Result<Buffer> {
let schema = self.inner.output_schema().await.default_error()?;
schema_to_buffer(&schema)
}
#[napi(catch_unwind)]
pub async fn execute(
&self,
@@ -346,6 +358,12 @@ impl TakeQuery {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)]
pub async fn output_schema(&self) -> napi::Result<Buffer> {
let schema = self.inner.output_schema().await.default_error()?;
schema_to_buffer(&schema)
}
#[napi(catch_unwind)]
pub async fn execute(
&self,

View File

@@ -3,7 +3,6 @@
use std::collections::HashMap;
use arrow_ipc::writer::FileWriter;
use lancedb::ipc::ipc_file_to_batches;
use lancedb::table::{
AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, NewColumnTransform,
@@ -16,6 +15,7 @@ use crate::error::NapiErrorExt;
use crate::index::Index;
use crate::merge::NativeMergeInsertBuilder;
use crate::query::{Query, TakeQuery, VectorQuery};
use crate::util::schema_to_buffer;
#[napi]
pub struct Table {
@@ -26,7 +26,7 @@ pub struct Table {
}
impl Table {
fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))
@@ -64,14 +64,7 @@ impl Table {
#[napi(catch_unwind)]
pub async fn schema(&self) -> napi::Result<Buffer> {
let schema = self.inner_ref()?.schema().await.default_error()?;
let mut writer = FileWriter::try_new(vec![], &schema)
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
writer
.finish()
.map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
Ok(Buffer::from(writer.into_inner().map_err(|e| {
napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
})?))
schema_to_buffer(&schema)
}
#[napi(catch_unwind)]

View File

@@ -1,7 +1,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow_ipc::writer::FileWriter;
use arrow_schema::Schema;
use lancedb::DistanceType;
use napi::bindgen_prelude::Buffer;
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
match distance_type.as_ref().to_lowercase().as_str() {
@@ -15,3 +18,15 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
))),
}
}
/// Convert an Arrow Schema to an Arrow IPC file buffer
pub fn schema_to_buffer(schema: &Schema) -> napi::Result<Buffer> {
let mut writer = FileWriter::try_new(vec![], schema)
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
writer
.finish()
.map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
Ok(Buffer::from(writer.into_inner().map_err(|e| {
napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
})?))
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.25.2-beta.2"
current_version = "0.25.3-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -24,6 +24,19 @@ commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# Update Cargo.lock after version bump
pre_commit_hooks = [
"""
cd python && cargo update -p lancedb-python
if git diff --quiet ../Cargo.lock; then
echo "Cargo.lock unchanged"
else
git add ../Cargo.lock
echo "Updated and staged Cargo.lock"
fi
""",
]
[tool.bumpversion.parts.pre_l]
values = ["beta", "final"]
optional_value = "final"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.25.2-beta.2"
version = "0.25.3-beta.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
@@ -14,12 +14,12 @@ name = "_lancedb"
crate-type = ["cdylib"]
[dependencies]
arrow = { version = "55.1", features = ["pyarrow"] }
arrow = { version = "56.2", features = ["pyarrow"] }
async-trait = "0.1"
lancedb = { path = "../rust/lancedb", default-features = false }
env_logger.workspace = true
pyo3 = { version = "0.24", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.24", features = [
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.25", features = [
"attributes",
"tokio-runtime",
] }
@@ -28,7 +28,7 @@ futures.workspace = true
tokio = { version = "1.40", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.24", features = [
pyo3-build-config = { version = "0.25", features = [
"extension-module",
"abi3-py39",
] }

View File

@@ -5,7 +5,7 @@ dynamic = ["version"]
dependencies = [
"deprecation",
"numpy",
"overrides>=0.7",
"overrides>=0.7; python_version<'3.12'",
"packaging",
"pyarrow>=16",
"pydantic>=1.10",

View File

@@ -123,6 +123,8 @@ class Table:
@property
def tags(self) -> Tags: ...
def query(self) -> Query: ...
def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
def vector_search(self) -> VectorQuery: ...
class Tags:
@@ -133,6 +135,7 @@ class Tags:
async def update(self, tag: str, version: int): ...
class IndexConfig:
name: str
index_type: str
columns: List[str]
@@ -164,6 +167,7 @@ class Query:
def postfilter(self): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> FTSQuery: ...
async def output_schema(self) -> pa.Schema: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
@@ -171,6 +175,13 @@ class Query:
async def analyze_plan(self) -> str: ...
def to_query_request(self) -> PyQueryRequest: ...
class TakeQuery:
def select(self, columns: List[str]): ...
def with_row_id(self): ...
async def output_schema(self) -> pa.Schema: ...
async def execute(self) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ...
class FTSQuery:
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
@@ -182,12 +193,14 @@ class FTSQuery:
def get_query(self) -> str: ...
def add_query_vector(self, query_vec: pa.Array) -> None: ...
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
async def output_schema(self) -> pa.Schema: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ...
class VectorQuery:
async def output_schema(self) -> pa.Schema: ...
async def execute(self) -> RecordBatchStream: ...
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
@@ -295,3 +308,34 @@ class AlterColumnsResult:
class DropColumnsResult:
version: int
class AsyncPermutationBuilder:
def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
def split_random(
self,
*,
ratios: Optional[List[float]] = None,
counts: Optional[List[int]] = None,
fixed: Optional[int] = None,
seed: Optional[int] = None,
) -> "AsyncPermutationBuilder": ...
def split_hash(
self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
) -> "AsyncPermutationBuilder": ...
def split_sequential(
self,
*,
ratios: Optional[List[float]] = None,
counts: Optional[List[int]] = None,
fixed: Optional[int] = None,
) -> "AsyncPermutationBuilder": ...
def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
def shuffle(
self, seed: Optional[int], clump_size: Optional[int]
) -> "AsyncPermutationBuilder": ...
def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
async def execute(self) -> Table: ...
def async_permutation_builder(
table: Table, dest_table_name: str
) -> AsyncPermutationBuilder: ...

View File

@@ -5,11 +5,20 @@
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
from pathlib import Path
import sys
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
if sys.version_info >= (3, 12):
from typing import override
class EnforceOverrides:
pass
else:
from overrides import EnforceOverrides, override # type: ignore
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from overrides import EnforceOverrides, override # type: ignore
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
from lancedb.background_loop import LOOP
@@ -32,7 +41,6 @@ import deprecation
if TYPE_CHECKING:
import pyarrow as pa
from .pydantic import LanceModel
from datetime import timedelta
from ._lancedb import Connection as LanceDbConnection
from .common import DATA, URI
@@ -444,7 +452,12 @@ class LanceDBConnection(DBConnection):
read_consistency_interval: Optional[timedelta] = None,
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
_inner: Optional[LanceDbConnection] = None,
):
if _inner is not None:
self._conn = _inner
return
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
@@ -453,11 +466,6 @@ class LanceDBConnection(DBConnection):
uri = Path(uri)
uri = uri.expanduser().absolute()
Path(uri).mkdir(parents=True, exist_ok=True)
self._uri = str(uri)
self._entered = False
self.read_consistency_interval = read_consistency_interval
self.storage_options = storage_options
self.session = session
if read_consistency_interval is not None:
read_consistency_interval_secs = read_consistency_interval.total_seconds()
@@ -476,10 +484,32 @@ class LanceDBConnection(DBConnection):
session,
)
# TODO: It would be nice if we didn't store self.storage_options but it is
# currently used by the LanceTable.to_lance method. This doesn't _really_
# work because some paths like LanceDBConnection.from_inner will lose the
# storage_options. Also, this class really shouldn't be holding any state
# beyond _conn.
self.storage_options = storage_options
self._conn = AsyncConnection(LOOP.run(do_connect()))
@property
def read_consistency_interval(self) -> Optional[timedelta]:
return LOOP.run(self._conn.get_read_consistency_interval())
@property
def session(self) -> Optional[Session]:
return self._conn.session
@property
def uri(self) -> str:
return self._conn.uri
@classmethod
def from_inner(cls, inner: LanceDbConnection):
return cls(None, _inner=inner)
def __repr__(self) -> str:
val = f"{self.__class__.__name__}(uri={self._uri!r}"
val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
@@ -489,6 +519,10 @@ class LanceDBConnection(DBConnection):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)
@property
def _inner(self) -> LanceDbConnection:
return self._conn._inner
@override
def list_namespaces(
self,
@@ -848,6 +882,13 @@ class AsyncConnection(object):
def uri(self) -> str:
return self._inner.uri
async def get_read_consistency_interval(self) -> Optional[timedelta]:
interval_secs = await self._inner.get_read_consistency_interval()
if interval_secs is not None:
return timedelta(seconds=interval_secs)
else:
return None
async def list_namespaces(
self,
namespace: List[str] = [],

View File

@@ -3,9 +3,11 @@
from functools import lru_cache
from typing import List, Union, Optional, Any
from logging import warning
from typing import List, Union, Optional, Any, Callable
import numpy as np
import io
import warnings
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
@@ -19,35 +21,52 @@ class ColPaliEmbeddings(EmbeddingFunction):
An embedding function that uses the ColPali engine for
multimodal multi-vector embeddings.
This embedding function supports ColQwen2.5 models, producing multivector outputs
for both text and image inputs. The output embeddings are lists of vectors, each
vector being 128-dimensional by default, represented as List[List[float]].
This embedding function supports ColPali models, producing multivector outputs
for both text and image inputs.
Parameters
----------
model_name : str
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
Supports models based on these engines:
- ColPali: "vidore/colpali-v1.3" and others
- ColQwen2.5: "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" and others
- ColQwen2: "vidore/colqwen2-v1.0" and others
- ColSmol: "vidore/colSmol-256M" and others
device : str
The device for inference (default "cuda:0").
The device for inference (default "auto").
dtype : str
Data type for model weights (default "bfloat16").
use_token_pooling : bool
Whether to use token pooling to reduce embedding size (default True).
DEPRECATED. Whether to use token pooling. Use `pooling_strategy` instead.
pooling_strategy : str, optional
The token pooling strategy to use, by default "hierarchical".
- "hierarchical": Progressively pools tokens to reduce sequence length.
- "lambda": A simpler pooling that uses a custom `pooling_func`.
pooling_func: typing.Callable, optional
A function to use for pooling when `pooling_strategy` is "lambda".
pool_factor : int
Factor to reduce sequence length if token pooling is enabled (default 2).
quantization_config : Optional[BitsAndBytesConfig]
Quantization configuration for the model. (default None, bitsandbytes needed)
batch_size : int
Batch size for processing inputs (default 2).
offload_folder: str, optional
Folder to offload model weights if using CPU offloading (default None). This is
useful for large models that do not fit in memory.
"""
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
device: str = "auto"
dtype: str = "bfloat16"
use_token_pooling: bool = True
pooling_strategy: Optional[str] = "hierarchical"
pooling_func: Optional[Any] = None
pool_factor: int = 2
quantization_config: Optional[Any] = None
batch_size: int = 2
offload_folder: Optional[str] = None
_model = None
_processor = None
@@ -56,15 +75,43 @@ class ColPaliEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
torch = attempt_import_or_raise("torch", "torch")
if not self.use_token_pooling:
warnings.warn(
"use_token_pooling is deprecated, use pooling_strategy=None instead",
DeprecationWarning,
)
self.pooling_strategy = None
if self.pooling_strategy == "lambda" and self.pooling_func is None:
raise ValueError(
"pooling_func must be provided when pooling_strategy is 'lambda'"
)
device = self.device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
dtype = self.dtype
if device == "mps" and dtype == "bfloat16":
dtype = "float32" # Avoid NaNs on MPS
(
self._model,
self._processor,
self._token_pooler,
) = self._load_model(
self.model_name,
self.dtype,
self.device,
self.use_token_pooling,
dtype,
device,
self.pooling_strategy,
self.pooling_func,
self.quantization_config,
)
@@ -74,16 +121,26 @@ class ColPaliEmbeddings(EmbeddingFunction):
model_name: str,
dtype: str,
device: str,
use_token_pooling: bool,
pooling_strategy: Optional[str],
pooling_func: Optional[Callable],
quantization_config: Optional[Any],
):
"""
Initialize and cache the ColPali model, processor, and token pooler.
"""
if device.startswith("mps"):
# Warn: some torch ops in the late-interaction architecture produce NaNs on MPS
warning(
"MPS device detected. Some operations may result in NaNs. "
"If you encounter issues, consider using 'cpu' or 'cuda' devices."
)
torch = attempt_import_or_raise("torch", "torch")
transformers = attempt_import_or_raise("transformers", "transformers")
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
from colpali_engine.compression.token_pooling import (
HierarchicalTokenPooler,
LambdaTokenPooler,
)
if quantization_config is not None:
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
@@ -98,21 +155,45 @@ class ColPaliEmbeddings(EmbeddingFunction):
else:
torch_dtype = torch.float32
model = colpali_engine.models.ColQwen2_5.from_pretrained(
model_class, processor_class = None, None
model_name_lower = model_name.lower()
if "colqwen2.5" in model_name_lower:
model_class = colpali_engine.models.ColQwen2_5
processor_class = colpali_engine.models.ColQwen2_5_Processor
elif "colsmol" in model_name_lower or "colidefics3" in model_name_lower:
model_class = colpali_engine.models.ColIdefics3
processor_class = colpali_engine.models.ColIdefics3Processor
elif "colqwen" in model_name_lower:
model_class = colpali_engine.models.ColQwen2
processor_class = colpali_engine.models.ColQwen2Processor
elif "colpali" in model_name_lower:
model_class = colpali_engine.models.ColPali
processor_class = colpali_engine.models.ColPaliProcessor
if model_class is None:
raise ValueError(f"Unsupported model: {model_name}")
model = model_class.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map=device,
quantization_config=quantization_config
if quantization_config is not None
else None,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
low_cpu_mem_usage=True,
).eval()
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
model_name
)
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
model = model.to(device)
model = model.to(torch_dtype) # Force cast after moving to device
processor = processor_class.from_pretrained(model_name)
token_pooler = None
if pooling_strategy == "hierarchical":
token_pooler = HierarchicalTokenPooler()
elif pooling_strategy == "lambda":
token_pooler = LambdaTokenPooler(pool_func=pooling_func)
return model, processor, token_pooler
def ndims(self):
@@ -128,7 +209,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
if self.use_token_pooling and self._token_pooler is not None:
if self.pooling_strategy and self._token_pooler is not None:
query_embeddings = self._token_pooler.pool_embeddings(
query_embeddings,
pool_factor=self.pool_factor,
@@ -145,13 +226,20 @@ class ColPaliEmbeddings(EmbeddingFunction):
Use token pooling if enabled.
"""
torch = attempt_import_or_raise("torch", "torch")
if self.use_token_pooling and self._token_pooler is not None:
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
if self.pooling_strategy and self._token_pooler is not None:
if self.pooling_strategy == "hierarchical":
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
elif self.pooling_strategy == "lambda":
embeddings = self._token_pooler.pool_embeddings(
embeddings,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
if isinstance(embeddings, torch.Tensor):
tensors = embeddings.detach().cpu()
@@ -179,6 +267,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
query_embeddings = torch.nan_to_num(query_embeddings)
all_embeddings.extend(self._process_embeddings(query_embeddings))
return all_embeddings
@@ -225,6 +314,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
)
with torch.no_grad():
image_embeddings = self._model(**batch_images)
image_embeddings = torch.nan_to_num(image_embeddings)
all_embeddings.extend(self._process_embeddings(image_embeddings))
return all_embeddings
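
The new pooling options can be exercised directly through the embedding registry; below is a minimal sketch, assuming torch and colpali_engine are installed, with a model name taken from the tests in this change:

from lancedb.embeddings import get_registry

registry = get_registry()

# Hierarchical pooling: reduces the number of token vectors by roughly pool_factor.
func_hier = registry.get("colpali").create(
    model_name="vidore/colSmol-256M",
    pooling_strategy="hierarchical",
    pool_factor=2,
)

# Lambda pooling: delegates reduction to a user-supplied function.
func_lambda = registry.get("colpali").create(
    model_name="vidore/colSmol-256M",
    pooling_strategy="lambda",
    pooling_func=lambda t: t[::2],  # keep every other token vector
)

# No pooling: replaces the deprecated use_token_pooling=False.
func_raw = registry.get("colpali").create(
    model_name="vidore/colSmol-256M",
    pooling_strategy=None,
)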

View File

@@ -605,9 +605,53 @@ class IvfPq:
target_partition_size: Optional[int] = None
@dataclass
class IvfRq:
"""Describes an IVF RQ Index
IVF-RQ (RabitQ quantization) stores a compressed copy of each vector using
RabitQ quantization and organizes them into IVF partitions. Parameters
largely mirror IVF-PQ for consistency.
Attributes
----------
distance_type: str, default "l2"
Distance metric used to train the index and for quantization.
The following distance types are available:
"l2" - Euclidean distance.
"cosine" - Cosine distance.
"dot" - Dot product.
num_partitions: int, default sqrt(num_rows)
Number of IVF partitions to create.
num_bits: int, default 1
Number of bits to encode each dimension.
max_iterations: int, default 50
Max iterations to train kmeans when computing IVF partitions.
sample_rate: int, default 256
Controls the number of training vectors: sample_rate * num_partitions.
target_partition_size: int, default 8192
Target size of each partition.
"""
distance_type: Literal["l2", "cosine", "dot"] = "l2"
num_partitions: Optional[int] = None
num_bits: int = 1
max_iterations: int = 50
sample_rate: int = 256
target_partition_size: Optional[int] = None
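A minimal sketch of using this config through the async API, mirroring the new async index test; the database URI, table name, and column name are illustrative:

# Assumes a table named "my_vectors" with a fixed-size-list "vector" column.
import lancedb
from lancedb.index import IvfRq

async def build_ivf_rq_index(uri: str) -> None:
    db = await lancedb.connect_async(uri)
    tbl = await db.open_table("my_vectors")
    # One bit per dimension; partitioning and training parameters keep their defaults.
    await tbl.create_index("vector", config=IvfRq(num_bits=1))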
__all__ = [
"BTree",
"IvfPq",
"IvfRq",
"IvfFlat",
"HnswPq",
"HnswSq",

View File

@@ -12,13 +12,18 @@ from __future__ import annotations
from typing import Dict, Iterable, List, Optional, Union
import os
import sys
if sys.version_info >= (3, 12):
from typing import override
else:
from overrides import override
from lancedb.db import DBConnection
from lancedb.table import LanceTable, Table
from lancedb.util import validate_table_name
from lancedb.common import validate_schema
from lancedb.table import sanitize_create_table
from overrides import override
from lance_namespace import LanceNamespace, connect as namespace_connect
from lance_namespace_urllib3_client.models import (

View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from ._lancedb import async_permutation_builder
from .table import LanceTable
from .background_loop import LOOP
from typing import Optional
class PermutationBuilder:
def __init__(self, table: LanceTable):
self._async = async_permutation_builder(table)
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
self._async.select(projections)
return self
def split_random(
self,
*,
ratios: Optional[list[float]] = None,
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
seed: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
return self
def split_hash(
self,
columns: list[str],
split_weights: list[int],
*,
discard_weight: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
return self
def split_sequential(
self,
*,
ratios: Optional[list[float]] = None,
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
return self
def split_calculated(self, calculation: str) -> "PermutationBuilder":
self._async.split_calculated(calculation)
return self
def shuffle(
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
) -> "PermutationBuilder":
self._async.shuffle(seed=seed, clump_size=clump_size)
return self
def filter(self, filter: str) -> "PermutationBuilder":
self._async.filter(filter)
return self
def execute(self) -> LanceTable:
async def do_execute():
inner_tbl = await self._async.execute()
return LanceTable.from_inner(inner_tbl)
return LOOP.run(do_execute())
def permutation_builder(table: LanceTable) -> PermutationBuilder:
return PermutationBuilder(table)
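
A minimal end-to-end sketch of the sync wrapper above; the local path and column names are illustrative, and the split and shuffle arguments follow the tests in this change:

import lancedb
import pyarrow as pa
from lancedb.permutation import permutation_builder

db = lancedb.connect("/tmp/permutation_demo")  # illustrative local database
tbl = db.create_table("source", pa.table({"x": range(100), "y": range(100)}))

permutation_tbl = (
    permutation_builder(tbl)
    .split_random(ratios=[0.8, 0.2], seed=42)  # adds a split_id column
    .shuffle(seed=123)  # randomizes row order within the permutation
    .execute()
)
print(permutation_tbl.count_rows())  # 100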

View File

@@ -1237,6 +1237,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor
return self
def output_schema(self) -> pa.Schema:
"""
Return the output schema for the query
This does not execute the query.
"""
return self._table._output_schema(self.to_query_object())
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and return the results as an
@@ -1452,6 +1460,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def output_schema(self) -> pa.Schema:
"""
Return the output schema for the query
This does not execute the query.
"""
return self._table._output_schema(self.to_query_object())
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
path, fs, exist = self._table._get_fts_index_path()
if exist:
@@ -1595,6 +1611,10 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def output_schema(self) -> pa.Schema:
query = self.to_query_object()
return self._table._output_schema(query)
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
) -> pa.RecordBatchReader:
@@ -2238,6 +2258,14 @@ class AsyncQueryBase(object):
)
)
async def output_schema(self) -> pa.Schema:
"""
Return the output schema for the query
This does not execute the query.
"""
return await self._inner.output_schema()
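A minimal sketch of inspecting a query's output schema on the async path before any data is read; the table and column names are illustrative:

import lancedb

async def show_output_schema(uri: str) -> None:
    db = await lancedb.connect_async(uri)
    tbl = await db.open_table("my_table")
    query = tbl.query().select(["id", "text"])
    schema = await query.output_schema()  # resolves the schema without executing
    print(schema)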
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and collect the results into an Apache Arrow Table.
@@ -3193,6 +3221,14 @@ class BaseQueryBuilder(object):
self._inner.with_row_id()
return self
def output_schema(self) -> pa.Schema:
"""
Return the output schema for the query
This does not execute the query.
"""
return LOOP.run(self._inner.output_schema())
def to_batches(
self,
*,

View File

@@ -5,15 +5,20 @@
from datetime import timedelta
import logging
from concurrent.futures import ThreadPoolExecutor
import sys
from typing import Any, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse
import warnings
if sys.version_info >= (3, 12):
from typing import override
else:
from overrides import override
# Remove this import to fix circular dependency
# from lancedb import connect_async
from lancedb.remote import ClientConfig
import pyarrow as pa
from overrides import override
from ..common import DATA
from ..db import DBConnection, LOOP

View File

@@ -114,7 +114,7 @@ class RemoteTable(Table):
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*,
replace: bool = False,
wait_timeout: timedelta = None,
wait_timeout: Optional[timedelta] = None,
name: Optional[str] = None,
):
"""Creates a scalar index
@@ -153,7 +153,7 @@ class RemoteTable(Table):
column: str,
*,
replace: bool = False,
wait_timeout: timedelta = None,
wait_timeout: Optional[timedelta] = None,
with_position: bool = False,
# tokenizer configs:
base_tokenizer: str = "simple",
@@ -436,6 +436,9 @@ class RemoteTable(Table):
def _analyze_plan(self, query: Query) -> str:
return LOOP.run(self._table._analyze_plan(query))
def _output_schema(self, query: Query) -> pa.Schema:
return LOOP.run(self._table._output_schema(query))
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
that can be used to create a "merge insert" operation.

View File

@@ -44,7 +44,7 @@ import numpy as np
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .index import BTree, IvfFlat, IvfPq, Bitmap, IvfRq, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import (
@@ -74,6 +74,7 @@ from .index import lang_mapping
if TYPE_CHECKING:
from .db import LanceDBConnection
from ._lancedb import (
Table as LanceDBTable,
OptimizeStats,
@@ -88,7 +89,6 @@ if TYPE_CHECKING:
MergeResult,
UpdateResult,
)
from .db import LanceDBConnection
from .index import IndexConfig
import pandas
import PIL
@@ -1248,6 +1248,9 @@ class Table(ABC):
@abstractmethod
def _analyze_plan(self, query: Query) -> str: ...
@abstractmethod
def _output_schema(self, query: Query) -> pa.Schema: ...
@abstractmethod
def _do_merge(
self,
@@ -1707,22 +1710,38 @@ class LanceTable(Table):
namespace: List[str] = [],
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
_async: Optional[AsyncTable] = None,
):
self._conn = connection
self._namespace = namespace
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
if _async is not None:
self._table = _async
else:
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
)
)
@property
def name(self) -> str:
return self._table.name
@classmethod
def from_inner(cls, tbl: LanceDBTable):
from .db import LanceDBConnection
async_tbl = AsyncTable(tbl)
conn = LanceDBConnection.from_inner(tbl.database())
return cls(
conn,
async_tbl.name,
_async=async_tbl,
)
@classmethod
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
tbl = cls(db, name, namespace=namespace, **kwargs)
@@ -1991,7 +2010,7 @@ class LanceTable(Table):
index_cache_size: Optional[int] = None,
num_bits: int = 8,
index_type: Literal[
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
"IVF_FLAT", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ",
max_iterations: int = 50,
sample_rate: int = 256,
@@ -2039,6 +2058,15 @@ class LanceTable(Table):
sample_rate=sample_rate,
target_partition_size=target_partition_size,
)
elif index_type == "IVF_RQ":
config = IvfRq(
distance_type=metric,
num_partitions=num_partitions,
num_bits=num_bits,
max_iterations=max_iterations,
sample_rate=sample_rate,
target_partition_size=target_partition_size,
)
elif index_type == "IVF_HNSW_PQ":
config = HnswPq(
distance_type=metric,
@@ -2736,6 +2764,9 @@ class LanceTable(Table):
def _analyze_plan(self, query: Query) -> str:
return LOOP.run(self._table._analyze_plan(query))
def _output_schema(self, query: Query) -> pa.Schema:
return LOOP.run(self._table._output_schema(query))
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
@@ -2747,6 +2778,10 @@ class LanceTable(Table):
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
@property
def _inner(self) -> LanceDBTable:
return self._table._inner
@deprecation.deprecated(
deprecated_in="0.21.0",
current_version=__version__,
@@ -3330,7 +3365,7 @@ class AsyncTable:
*,
replace: Optional[bool] = None,
config: Optional[
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
Union[IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
wait_timeout: Optional[timedelta] = None,
name: Optional[str] = None,
@@ -3369,11 +3404,12 @@ class AsyncTable:
"""
if config is not None:
if not isinstance(
config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
config,
(IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS),
):
raise TypeError(
"config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"
" Bitmap, LabelList, or FTS"
"config must be an instance of IvfPq, IvfRq, HnswPq, HnswSq, BTree,"
" Bitmap, LabelList, or FTS, but got " + str(type(config))
)
try:
await self._inner.create_index(
@@ -3888,6 +3924,10 @@ class AsyncTable:
async_query = self._sync_query_to_async(query)
return await async_query.analyze_plan()
async def _output_schema(self, query: Query) -> pa.Schema:
async_query = self._sync_query_to_async(query)
return await async_query.output_schema()
async def _do_merge(
self,
merge: LanceMergeInsertBuilder,

View File

@@ -18,10 +18,17 @@ AddMode = Literal["append", "overwrite"]
CreateMode = Literal["create", "overwrite"]
# Index type literals
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", "IVF_RQ"]
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
IndexType = Literal[
"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
"IVF_PQ",
"IVF_HNSW_PQ",
"IVF_HNSW_SQ",
"FTS",
"BTREE",
"BITMAP",
"LABEL_LIST",
"IVF_RQ",
]
# Tokenizer literals

View File

@@ -656,6 +656,106 @@ def test_colpali(tmp_path):
)
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
@pytest.mark.parametrize(
"model_name",
[
"vidore/colSmol-256M",
"vidore/colqwen2.5-v0.2",
"vidore/colpali-v1.3",
"vidore/colqwen2-v1.0",
],
)
def test_colpali_models(tmp_path, model_name):
import requests
from lancedb.pydantic import LanceModel
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("colpali").create(model_name=model_name)
class MediaItems(LanceModel):
text: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
image_vectors: MultiVector(func.ndims()) = func.VectorField()
table = db.create_table(f"media_{model_name.replace('/', '_')}", schema=MediaItems)
texts = [
"a cute cat playing with yarn",
]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
]
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)
image_results = (
table.search("fluffy companion", vector_column_name="image_vectors")
.limit(1)
.to_pydantic(MediaItems)[0]
)
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
assert len(first_row["image_vectors"][0]) == func.ndims(), (
"Vector dimension mismatch"
)
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
def test_colpali_pooling(tmp_path):
registry = get_registry()
model_name = "vidore/colSmol-256M"
test_sentence = "a test sentence for pooling"
# 1. Get embeddings with no pooling
func_no_pool = registry.get("colpali").create(
model_name=model_name, pooling_strategy=None
)
unpooled_embeddings = func_no_pool.generate_text_embeddings([test_sentence])[0]
original_length = len(unpooled_embeddings)
assert original_length > 1
# 2. Test hierarchical pooling
func_hierarchical = registry.get("colpali").create(
model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
)
hierarchical_embeddings = func_hierarchical.generate_text_embeddings(
[test_sentence]
)[0]
expected_hierarchical_length = (original_length + 1) // 2
assert len(hierarchical_embeddings) == expected_hierarchical_length
# 3. Test lambda pooling
def simple_pool_func(tensor):
return tensor[::2]
func_lambda = registry.get("colpali").create(
model_name=model_name,
pooling_strategy="lambda",
pooling_func=simple_pool_func,
)
lambda_embeddings = func_lambda.generate_text_embeddings([test_sentence])[0]
expected_lambda_length = (original_length + 1) // 2
assert len(lambda_embeddings) == expected_lambda_length
@pytest.mark.slow
def test_siglip(tmp_path, test_images, query_image_bytes):
from PIL import Image

View File

@@ -8,7 +8,17 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from lancedb.index import (
BTree,
IvfFlat,
IvfPq,
IvfRq,
Bitmap,
LabelList,
HnswPq,
HnswSq,
FTS,
)
@pytest_asyncio.fixture
@@ -195,6 +205,16 @@ async def test_create_4bit_ivfpq_index(some_table: AsyncTable):
assert stats.loss >= 0.0
@pytest.mark.asyncio
async def test_create_ivfrq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=IvfRq(num_bits=1))
indices = await some_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "IvfRq"
assert indices[0].columns == ["vector"]
assert indices[0].name == "vector_idx"
@pytest.mark.asyncio
async def test_create_hnswpq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswPq(num_partitions=10))

View File

@@ -0,0 +1,462 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import pytest
from lancedb.permutation import permutation_builder
def test_split_random_ratios(mem_db):
"""Test random splitting with ratios."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
# Check that the table was created and has data
assert permutation_tbl.count_rows() == 100
# Check that split_id column exists and has correct values
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1}
# Check approximate split sizes (allowing for rounding)
split_0_count = split_ids.count(0)
split_1_count = split_ids.count(1)
assert 25 <= split_0_count <= 35 # ~30% ± tolerance
assert 65 <= split_1_count <= 75 # ~70% ± tolerance
def test_split_random_counts(mem_db):
"""Test random splitting with absolute counts."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
# Check that we have exactly the requested counts
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert split_ids.count(0) == 20
assert split_ids.count(1) == 30
def test_split_random_fixed(mem_db):
"""Test random splitting with fixed number of splits."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()
# Check that we have 4 splits with 25 rows each
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1, 2, 3}
for split_id in range(4):
assert split_ids.count(split_id) == 25
def test_split_random_with_seed(mem_db):
"""Test that seeded random splits are reproducible."""
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
# Create two identical permutations with same seed
perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
# Results should be identical
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
assert data1["row_id"] == data2["row_id"]
assert data1["split_id"] == data2["split_id"]
def test_split_hash(mem_db):
"""Test hash-based splitting."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100], # Repeating pattern
"value": range(100),
}
),
)
permutation_tbl = (
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
# Should have all 100 rows (no discard)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1}
# Verify that each split has roughly 50 rows (allowing for hash variance)
split_0_count = split_ids.count(0)
split_1_count = split_ids.count(1)
assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution
assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution
# Hash splits should be deterministic - same category should go to same split
# Let's verify by creating another permutation and checking consistency
perm2 = (
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
data2 = perm2.search(None).to_arrow().to_pydict()
assert data["split_id"] == data2["split_id"] # Should be identical
def test_split_hash_with_discard(mem_db):
"""Test hash-based splitting with discard weight."""
tbl = mem_db.create_table(
"test_table",
pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
)
permutation_tbl = (
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
.execute()
)
# Should have fewer than 100 rows due to discard
row_count = permutation_tbl.count_rows()
assert row_count < 100
assert row_count > 0 # But not empty
def test_split_sequential(mem_db):
"""Test sequential splitting."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
)
assert permutation_tbl.count_rows() == 70
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
split_ids = data["split_id"]
# Sequential should maintain order
assert row_ids == sorted(row_ids)
# First 30 should be split 0, next 40 should be split 1
assert split_ids[:30] == [0] * 30
assert split_ids[30:] == [1] * 40
def test_split_calculated(mem_db):
"""Test calculated splitting."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
permutation_tbl = (
permutation_builder(tbl)
.split_calculated("id % 3") # Split based on id modulo 3
.execute()
)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
split_ids = data["split_id"]
# Verify the calculation: each row's split_id should equal row_id % 3
for row_id, split_id in zip(row_ids, split_ids):
assert split_id == row_id % 3
def test_split_error_cases(mem_db):
"""Test error handling for invalid split parameters."""
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
# Test split_random with no parameters
with pytest.raises(Exception):
permutation_builder(tbl).split_random().execute()
# Test split_random with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl).split_random(
ratios=[0.5, 0.5], counts=[5, 5]
).execute()
# Test split_sequential with no parameters
with pytest.raises(Exception):
permutation_builder(tbl).split_sequential().execute()
# Test split_sequential with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
def test_shuffle_no_seed(mem_db):
"""Test shuffling without a seed."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
# Create a permutation with shuffling (no seed)
permutation_tbl = permutation_builder(tbl).shuffle().execute()
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# Row IDs should not be in sequential order due to shuffling
# This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
# in order
assert row_ids != list(range(100))
def test_shuffle_with_seed(mem_db):
"""Test that shuffling with a seed is reproducible."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create two identical permutations with same shuffle seed
perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
perm2 = permutation_builder(tbl).shuffle(seed=42).execute()
# Results should be identical due to same seed
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
assert data1["row_id"] == data2["row_id"]
assert data1["split_id"] == data2["split_id"]
def test_shuffle_with_clump_size(mem_db):
"""Test shuffling with clump size."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
# Create a permutation with shuffling using clumps
permutation_tbl = (
permutation_builder(tbl)
.shuffle(clump_size=10) # 10-row clumps
.execute()
)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
for i in range(10):
start = row_ids[i * 10]
assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))
def test_shuffle_different_seeds(mem_db):
"""Test that different seeds produce different shuffle orders."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create two permutations with different shuffle seeds
perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()
# Results should be different due to different seeds
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
# Row order should be different
assert data1["row_id"] != data2["row_id"]
def test_shuffle_combined_with_splits(mem_db):
"""Test shuffling combined with different split strategies."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100],
"value": range(100),
}
),
)
# Test shuffle with random splits
perm_random = (
permutation_builder(tbl)
.split_random(ratios=[0.6, 0.4], seed=42)
.shuffle(seed=123, clump_size=None)
.execute()
)
# Test shuffle with hash splits
perm_hash = (
permutation_builder(tbl)
.split_hash(["category"], [1, 1], discard_weight=0)
.shuffle(seed=456, clump_size=5)
.execute()
)
# Test shuffle with sequential splits
perm_sequential = (
permutation_builder(tbl)
.split_sequential(counts=[40, 35])
.shuffle(seed=789, clump_size=None)
.execute()
)
# Verify all permutations work and have expected properties
assert perm_random.count_rows() == 100
assert perm_hash.count_rows() == 100
assert perm_sequential.count_rows() == 75
# Verify shuffle affected the order
data_random = perm_random.search(None).to_arrow().to_pydict()
data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
assert data_random["row_id"] != list(range(100))
assert data_sequential["row_id"] != list(range(75))
def test_no_shuffle_maintains_order(mem_db):
"""Test that not calling shuffle maintains the original order."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create permutation without shuffle (should maintain the original order)
permutation_tbl = (
permutation_builder(tbl)
.split_sequential(counts=[25, 25]) # Sequential maintains order
.execute()
)
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# With sequential splits and no shuffle, should maintain order
assert row_ids == list(range(50))
def test_filter_basic(mem_db):
"""Test basic filtering functionality."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100, 200)})
)
# Filter to only include rows where id < 50
permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# All row_ids should be less than 50
assert all(row_id < 50 for row_id in row_ids)
def test_filter_with_splits(mem_db):
"""Test filtering combined with split strategies."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100],
"value": range(100),
}
),
)
# Filter to only category A and B, then split
permutation_tbl = (
permutation_builder(tbl)
.filter("category IN ('A', 'B')")
.split_random(ratios=[0.5, 0.5])
.execute()
)
# Should have fewer than 100 rows due to filtering
row_count = permutation_tbl.count_rows()
assert row_count == 67
data = permutation_tbl.search(None).to_arrow().to_pydict()
categories = data["category"]
# All categories should be A or B
assert all(cat in ["A", "B"] for cat in categories)
def test_filter_with_shuffle(mem_db):
"""Test filtering combined with shuffling."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C", "D"] * 25)[:100],
"value": range(100),
}
),
)
# Filter and shuffle
permutation_tbl = (
permutation_builder(tbl)
.filter("category IN ('A', 'C')")
.shuffle(seed=42)
.execute()
)
row_count = permutation_tbl.count_rows()
assert row_count == 50 # Should have 50 rows (A and C categories)
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
assert row_ids != sorted(row_ids)
def test_filter_empty_result(mem_db):
"""Test filtering that results in empty set."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(10), "value": range(10)})
)
# Filter that matches nothing
permutation_tbl = (
permutation_builder(tbl)
.filter("value > 100") # No values > 100 in our data
.execute()
)
assert permutation_tbl.count_rows() == 0

View File

@@ -1298,6 +1298,79 @@ async def test_query_serialization_async(table_async: AsyncTable):
)
def test_query_schema(tmp_path):
db = lancedb.connect(tmp_path)
tbl = db.create_table(
"test",
pa.table(
{
"a": [1, 2, 3],
"text": ["a", "b", "c"],
"vec": pa.array(
[[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
),
}
),
)
assert tbl.search(None).output_schema() == pa.schema(
{
"a": pa.int64(),
"text": pa.string(),
"vec": pa.list_(pa.float32(), list_size=2),
}
)
assert tbl.search(None).select({"bl": "a * 2"}).output_schema() == pa.schema(
{"bl": pa.int64()}
)
assert tbl.search([1, 2]).select(["a"]).output_schema() == pa.schema(
{"a": pa.int64(), "_distance": pa.float32()}
)
assert tbl.search("blah").select(["a"]).output_schema() == pa.schema(
{"a": pa.int64()}
)
assert tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
{"text": pa.string()}
)
@pytest.mark.asyncio
async def test_query_schema_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
tbl = await db.create_table(
"test",
pa.table(
{
"a": [1, 2, 3],
"text": ["a", "b", "c"],
"vec": pa.array(
[[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
),
}
),
)
assert await tbl.query().output_schema() == pa.schema(
{
"a": pa.int64(),
"text": pa.string(),
"vec": pa.list_(pa.float32(), list_size=2),
}
)
assert await tbl.query().select({"bl": "a * 2"}).output_schema() == pa.schema(
{"bl": pa.int64()}
)
assert await tbl.vector_search([1, 2]).select(["a"]).output_schema() == pa.schema(
{"a": pa.int64(), "_distance": pa.float32()}
)
assert await (await tbl.search("blah")).select(["a"]).output_schema() == pa.schema(
{"a": pa.int64()}
)
assert await tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
{"text": pa.string()}
)
def test_query_timeout(tmp_path):
# Use local directory instead of memory:// to add a bit of latency to
# operations so a timeout of zero will trigger exceptions.

View File

@@ -4,7 +4,10 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
use lancedb::{
connection::Connection as LanceConnection,
database::{CreateTableMode, ReadConsistency},
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -23,7 +26,7 @@ impl Connection {
Self { inner: Some(inner) }
}
fn get_inner(&self) -> PyResult<&LanceConnection> {
pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -63,6 +66,18 @@ impl Connection {
self.get_inner().map(|inner| inner.uri().to_string())
}
#[pyo3(signature = ())]
pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
Ok(match inner.read_consistency().await.infer_error()? {
ReadConsistency::Manual => None,
ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
ReadConsistency::Strong => Some(0.0_f64),
})
})
}
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
pub fn table_names(
self_: PyRef<'_, Self>,

View File

@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use lancedb::index::vector::IvfFlatIndexBuilder;
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder};
use lancedb::index::{
scalar::{BTreeIndexBuilder, FtsIndexBuilder},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
@@ -87,6 +87,22 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
}
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
},
"IvfRq" => {
let params = source.extract::<IvfRqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut ivf_rq_builder = IvfRqIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate)
.num_bits(params.num_bits);
if let Some(num_partitions) = params.num_partitions {
ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
}
if let Some(target_partition_size) = params.target_partition_size {
ivf_rq_builder = ivf_rq_builder.target_partition_size(target_partition_size);
}
Ok(LanceDbIndex::IvfRq(ivf_rq_builder))
},
"HnswPq" => {
let params = source.extract::<IvfHnswPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -170,6 +186,16 @@ struct IvfPqParams {
target_partition_size: Option<u32>,
}
#[derive(FromPyObject)]
struct IvfRqParams {
distance_type: String,
num_partitions: Option<u32>,
num_bits: u32,
max_iterations: u32,
sample_rate: u32,
target_partition_size: Option<u32>,
}
#[derive(FromPyObject)]
struct IvfHnswPqParams {
distance_type: String,

View File

@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
use permutation::PyAsyncPermutationBuilder;
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -22,6 +23,7 @@ pub mod connection;
pub mod error;
pub mod header;
pub mod index;
pub mod permutation;
pub mod query;
pub mod session;
pub mod table;
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DeleteResult>()?;
m.add_class::<DropColumnsResult>()?;
m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

python/src/permutation.rs Normal file
View File

@@ -0,0 +1,170 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
permutation::split::{SplitSizes, SplitStrategy},
};
use pyo3::{
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
PyResult,
};
use pyo3_async_runtimes::tokio::future_into_py;
/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
let inner_table = table.borrow().inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PyAsyncPermutationBuilder {
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
builder: Some(inner_builder),
})),
})
}
struct PyAsyncPermutationBuilderState {
builder: Option<LancePermutationBuilder>,
}
#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}
impl PyAsyncPermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> PyResult<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[pymethods]
impl PyAsyncPermutationBuilder {
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
pub fn split_random(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
seed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
pub fn split_hash(
slf: PyRefMut<'_, Self>,
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns,
split_weights,
discard_weight,
})
})
}
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
pub fn split_sequential(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
}
pub fn shuffle(
slf: PyRefMut<'_, Self>,
seed: Option<u64>,
clump_size: Option<u64>,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_filter(filter))
}
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let mut state = slf.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
future_into_py(slf.py(), async move {
let table = builder.build().await.infer_error()?;
Ok(Table::new(table))
})
}
}

View File

@@ -9,6 +9,7 @@ use arrow::array::Array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use arrow::pyarrow::IntoPyArrow;
use arrow::pyarrow::ToPyArrow;
use lancedb::index::scalar::{
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
Operator, PhraseQuery,
@@ -30,6 +31,7 @@ use pyo3::IntoPyObject;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3::Python;
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
use pyo3::{
exceptions::{PyNotImplementedError, PyValueError},
@@ -445,6 +447,15 @@ impl Query {
})
}
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
@@ -515,6 +526,15 @@ impl TakeQuery {
self.inner = self.inner.clone().with_row_id();
}
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
@@ -601,6 +621,15 @@ impl FTSQuery {
self.inner = self.inner.clone().postfilter();
}
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
@@ -771,6 +800,15 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index()
}
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,

View File

@@ -3,6 +3,7 @@
use std::{collections::HashMap, sync::Arc};
use crate::{
connection::Connection,
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
@@ -249,7 +250,7 @@ impl Table {
}
impl Table {
fn inner_ref(&self) -> PyResult<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -272,6 +273,13 @@ impl Table {
self.inner.take();
}
pub fn database(&self) -> PyResult<Connection> {
let inner = self.inner_ref()?.clone();
let inner_connection =
lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
Ok(Connection::new(inner_connection))
}
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {

View File

@@ -1,2 +1,2 @@
[toolchain]
channel = "1.86.0"
channel = "1.90.0"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.22.2-beta.1"
version = "0.22.3-beta.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -11,10 +11,12 @@ rust-version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
@@ -24,12 +26,16 @@ datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-plan.workspace = true
datafusion.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
lance-core = { workspace = true }
lance-datafusion.workspace = true
lance-datagen = { workspace = true }
lance-file = { workspace = true }
lance-io = { workspace = true }
lance-index = { workspace = true }
lance-table = { workspace = true }
@@ -37,6 +43,7 @@ lance-linalg = { workspace = true }
lance-testing = { workspace = true }
lance-encoding = { workspace = true }
lance-namespace = { workspace = true }
lance-namespace-impls = { workspace = true, features = ["dir", "rest"] }
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
@@ -46,11 +53,13 @@ bytes = "1"
futures.workspace = true
num-traits.workspace = true
url.workspace = true
rand.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
tempfile = "3.5.0"
aws-sdk-bedrockruntime = { version = "1.27.0", optional = true }
# For remote feature
reqwest = { version = "0.12.0", default-features = false, features = [
@@ -61,9 +70,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"macos-system-configuration",
"stream",
], optional = true }
rand = { version = "0.9", features = ["small_rng"], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
uuid = { version = "1.7.0", features = ["v4"], optional = true }
uuid = { version = "1.7.0", features = ["v4"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [
@@ -84,7 +92,6 @@ bytemuck_derive.workspace = true
[dev-dependencies]
anyhow = "1"
tempfile = "3.5.0"
rand = { version = "0.9", features = ["small_rng"] }
random_word = { version = "0.4.3", features = ["en"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
@@ -96,6 +103,7 @@ aws-smithy-runtime = { version = "1.9.1" }
datafusion.workspace = true
http-body = "1" # Matching reqwest
rstest = "0.23.0"
test-log = "0.2"
[features]
@@ -105,7 +113,7 @@ oss = ["lance/oss", "lance-io/oss"]
gcs = ["lance/gcp", "lance-io/gcp"]
azure = ["lance/azure", "lance-io/azure"]
dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
remote = ["dep:reqwest", "dep:http"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -7,6 +7,7 @@ pub use arrow_schema;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{Stream, StreamExt, TryStreamExt};
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
#[cfg(feature = "polars")]
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
@@ -161,6 +162,26 @@ impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
}
}
pub trait LanceDbDatagenExt {
fn into_ldb_stream(
self,
batch_size: RowCount,
num_batches: BatchCount,
) -> SendableRecordBatchStream;
}
impl LanceDbDatagenExt for BatchGeneratorBuilder {
fn into_ldb_stream(
self,
batch_size: RowCount,
num_batches: BatchCount,
) -> SendableRecordBatchStream {
let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
let stream = stream.map_err(|err| Error::Arrow { source: err });
Box::pin(SimpleRecordBatchStream::new(stream, schema))
}
}
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {

View File

@@ -19,7 +19,7 @@ use crate::database::listing::{
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
OpenTableRequest, TableNamesRequest,
OpenTableRequest, ReadConsistency, TableNamesRequest,
};
use crate::embeddings::{
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
@@ -152,6 +152,7 @@ impl CreateTableBuilder<true> {
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
parent,
embedding_registry,
))
}
@@ -211,9 +212,9 @@ impl CreateTableBuilder<false> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
Ok(Table::new(
self.parent.clone().create_table(self.request).await?,
))
let parent = self.parent.clone();
let table = parent.create_table(self.request).await?;
Ok(Table::new(table, parent))
}
}
@@ -462,8 +463,10 @@ impl OpenTableBuilder {
/// Open the table
pub async fn execute(self) -> Result<Table> {
let table = self.parent.open_table(self.request).await?;
Ok(Table::new_with_embedding_registry(
self.parent.clone().open_table(self.request).await?,
table,
self.parent,
self.embedding_registry,
))
}
@@ -519,16 +522,15 @@ impl CloneTableBuilder {
/// Execute the clone operation
pub async fn execute(self) -> Result<Table> {
Ok(Table::new(
self.parent.clone().clone_table(self.request).await?,
))
let parent = self.parent.clone();
let table = parent.clone_table(self.request).await?;
Ok(Table::new(table, parent))
}
}
/// A connection to LanceDB
#[derive(Clone)]
pub struct Connection {
uri: String,
internal: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -540,9 +542,19 @@ impl std::fmt::Display for Connection {
}
impl Connection {
pub fn new(
internal: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
internal,
embedding_registry,
}
}
/// Get the URI of the connection
pub fn uri(&self) -> &str {
self.uri.as_str()
self.internal.uri()
}
/// Get access to the underlying database
@@ -675,6 +687,11 @@ impl Connection {
.await
}
/// Get the read consistency of the connection
pub async fn read_consistency(&self) -> Result<ReadConsistency> {
self.internal.read_consistency().await
}
/// Drop a table in the database.
///
/// # Arguments
@@ -973,7 +990,6 @@ impl ConnectBuilder {
)?);
Ok(Connection {
internal,
uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -996,7 +1012,6 @@ impl ConnectBuilder {
let internal = Arc::new(ListingDatabase::connect_with_options(&self.request).await?);
Ok(Connection {
internal,
uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1104,7 +1119,6 @@ impl ConnectNamespaceBuilder {
Ok(Connection {
internal,
uri: format!("namespace://{}", self.ns_impl),
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1139,7 +1153,6 @@ mod test_utils {
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
internal,
uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1156,7 +1169,6 @@ mod test_utils {
));
Self {
internal,
uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1170,7 +1182,7 @@ mod tests {
use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
use crate::query::QueryBase;
use crate::query::{ExecutableQuery, QueryExecutionOptions};
use crate::test_connection::test_utils::new_test_connection;
use crate::test_utils::connection::new_test_connection;
use arrow::compute::concat_batches;
use arrow_array::RecordBatchReader;
use arrow_schema::{DataType, Field, Schema};
@@ -1187,7 +1199,7 @@ mod tests {
#[tokio::test]
async fn test_connect() {
let tc = new_test_connection().await.unwrap();
assert_eq!(tc.connection.uri, tc.uri);
assert_eq!(tc.connection.uri(), tc.uri);
}
#[cfg(not(windows))]
@@ -1208,7 +1220,7 @@ mod tests {
.await
.unwrap();
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
assert_eq!(db.uri(), relative_uri.to_str().unwrap().to_string());
}
#[tokio::test]

View File

@@ -16,6 +16,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
@@ -213,6 +214,20 @@ impl CloneTableRequest {
}
}
/// How long until a change is reflected from one Table instance to another
///
/// Tables are always internally consistent. If a write method is called on
/// a table instance, the change is immediately visible in that same table instance.
pub enum ReadConsistency {
/// Changes will not be automatically propagated until the checkout_latest
/// method is called on the target table
Manual,
/// Changes will be propagated automatically within the given duration
Eventual(Duration),
/// Changes are immediately visible in target tables
Strong,
}
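// Illustrative sketch of consuming this enum via `Connection::read_consistency`
// (assumes `conn` is an open `Connection`; not a prescribed pattern):
//
//     match conn.read_consistency().await? {
//         ReadConsistency::Manual => println!("call checkout_latest to pick up new writes"),
//         ReadConsistency::Eventual(d) => println!("writes become visible within {:?}", d),
//         ReadConsistency::Strong => println!("writes are visible immediately"),
//     }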
/// The `Database` trait defines the interface for database implementations.
///
/// A database is responsible for managing tables and their metadata.
@@ -220,6 +235,10 @@ impl CloneTableRequest {
pub trait Database:
Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
{
/// Get the uri of the database
fn uri(&self) -> &str;
/// Get the read consistency of the database
async fn read_consistency(&self) -> Result<ReadConsistency>;
/// List immediate child namespace names in the given namespace
async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>>;
/// Create a new namespace

View File

@@ -17,6 +17,7 @@ use object_store::local::LocalFileSystem;
use snafu::ResultExt;
use crate::connection::ConnectRequest;
use crate::database::ReadConsistency;
use crate::error::{CreateDirSnafu, Error, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::NativeTable;
@@ -598,6 +599,22 @@ impl Database for ListingDatabase {
Ok(Vec::new())
}
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
if let Some(read_consistency_interval) = self.read_consistency_interval {
if read_consistency_interval.is_zero() {
Ok(ReadConsistency::Strong)
} else {
Ok(ReadConsistency::Eventual(read_consistency_interval))
}
} else {
Ok(ReadConsistency::Manual)
}
}
async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> {
Err(Error::NotSupported {
message: "Namespace operations are not supported for listing database".into(),
@@ -1249,7 +1266,8 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let db = Arc::new(db);
let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1320,7 +1338,8 @@ mod tests {
.unwrap();
// Create a tag for the current version
let source_table_obj = Table::new(source_table.clone());
let db = Arc::new(db);
let source_table_obj = Table::new(source_table.clone(), db.clone());
let mut tags = source_table_obj.tags().await.unwrap();
tags.create("v1.0", source_table.version().await.unwrap())
.await
@@ -1336,7 +1355,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1432,7 +1451,8 @@ mod tests {
)
.unwrap();
let cloned_table_obj = Table::new(cloned_table.clone());
let db = Arc::new(db);
let cloned_table_obj = Table::new(cloned_table.clone(), db.clone());
cloned_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_clone)],
@@ -1452,7 +1472,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let source_table_obj = Table::new(source_table.clone(), db);
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_source)],
@@ -1495,6 +1515,7 @@ mod tests {
.unwrap();
// Add more data to create new versions
let db = Arc::new(db);
for i in 0..3 {
let batch = RecordBatch::try_new(
schema.clone(),
@@ -1502,7 +1523,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch)],

View File

@@ -8,17 +8,17 @@ use std::sync::Arc;
use async_trait::async_trait;
use lance_namespace::{
connect as connect_namespace,
models::{
CreateEmptyTableRequest, CreateNamespaceRequest, DescribeTableRequest,
DropNamespaceRequest, DropTableRequest, ListNamespacesRequest, ListTablesRequest,
},
LanceNamespace,
};
use lance_namespace_impls::connect::connect as connect_namespace;
use crate::connection::ConnectRequest;
use crate::database::listing::ListingDatabase;
use crate::error::{Error, Result};
use crate::{connection::ConnectRequest, database::ReadConsistency};
use super::{
BaseTable, CloneTableRequest, CreateNamespaceRequest as DbCreateNamespaceRequest,
@@ -36,6 +36,8 @@ pub struct LanceNamespaceDatabase {
read_consistency_interval: Option<std::time::Duration>,
// Optional session for object stores and caching
session: Option<Arc<lance::session::Session>>,
// database URI
uri: String,
}
impl LanceNamespaceDatabase {
@@ -57,6 +59,7 @@ impl LanceNamespaceDatabase {
storage_options,
read_consistency_interval,
session,
uri: format!("namespace://{}", ns_impl),
})
}
@@ -130,6 +133,22 @@ impl std::fmt::Display for LanceNamespaceDatabase {
#[async_trait]
impl Database for LanceNamespaceDatabase {
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
if let Some(read_consistency_interval) = self.read_consistency_interval {
if read_consistency_interval.is_zero() {
Ok(ReadConsistency::Strong)
} else {
Ok(ReadConsistency::Eventual(read_consistency_interval))
}
} else {
Ok(ReadConsistency::Manual)
}
}
async fn list_namespaces(&self, request: DbListNamespacesRequest) -> Result<Vec<String>> {
let ns_request = ListNamespacesRequest {
id: if request.namespace.is_empty() {
@@ -261,7 +280,7 @@ impl Database for LanceNamespaceDatabase {
return listing_db
.open_table(OpenTableRequest {
name: request.name.clone(),
namespace: request.namespace.clone(),
namespace: vec![],
index_cache_size: None,
lance_read_params: None,
})
@@ -305,7 +324,14 @@ impl Database for LanceNamespaceDatabase {
)
.await?;
listing_db.create_table(request).await
let create_request = DbCreateTableRequest {
name: request.name,
namespace: vec![],
data: request.data,
mode: request.mode,
write_options: request.write_options,
};
listing_db.create_table(create_request).await
}
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
@@ -332,7 +358,13 @@ impl Database for LanceNamespaceDatabase {
.create_listing_database(&request.name, &location, response.storage_options)
.await?;
listing_db.open_table(request).await
let open_request = OpenTableRequest {
name: request.name.clone(),
namespace: vec![],
index_cache_size: request.index_cache_size,
lance_read_params: request.lance_read_params,
};
listing_db.open_table(open_request).await
}
async fn clone_table(&self, _request: CloneTableRequest) -> Result<Arc<dyn BaseTable>> {

View File

@@ -0,0 +1,4 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod permutation;

View File

@@ -0,0 +1,18 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Contains the [PermutationBuilder] to create a permutation "view" of an existing table.
//!
//! A permutation view can apply a filter, divide the data into splits, and shuffle the data.
//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
//! the underlying data and can be very lightweight.
//!
//! Building a permutation table should be fairly quick (it is an O(N) operation where N is
//! the number of rows in the base table) and memory efficient, even for billions or trillions
//! of rows.
pub mod builder;
pub mod reader;
pub mod shuffle;
pub mod split;
pub mod util;
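// Illustrative sketch of the typical flow (assumes an existing LanceDB `Table` named
// `data_table`; column names and values are placeholders):
//
//     let permutation = builder::PermutationBuilder::new(data_table.clone())
//         .with_filter("label IS NOT NULL".to_string())
//         .with_split_strategy(split::SplitStrategy::Random {
//             seed: Some(42),
//             sizes: split::SplitSizes::Percentages(vec![0.8, 0.2]),
//         })
//         .build()
//         .await?;
//     // The resulting table holds only `row_id` and `split_id` columns.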

View File

@@ -0,0 +1,326 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_core::ROW_ID;
use lance_datafusion::exec::SessionContextExt;
use crate::{
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
connect,
database::{CreateTableData, CreateTableRequest, Database},
dataloader::permutation::{
shuffle::{Shuffler, ShufflerConfig},
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase},
Error, Result, Table,
};
pub const SRC_ROW_ID_COL: &str = "row_id";
/// Where to store the permutation table
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
/// The permutation table is a temporary table in memory
#[default]
Temporary,
/// The permutation table is a permanent table in a database
Permanent(Arc<dyn Database>, String),
}
/// Configuration for creating a permutation table
#[derive(Debug, Default)]
pub struct PermutationConfig {
/// Splitting configuration
split_strategy: SplitStrategy,
/// Shuffle strategy
shuffle_strategy: ShuffleStrategy,
/// Optional filter to apply to the base table
filter: Option<String>,
/// Directory to use for temporary files
temp_dir: TemporaryDirectory,
/// Destination
destination: PermutationDestination,
}
/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
pub enum ShuffleStrategy {
/// The data is randomly shuffled
///
/// A seed can be provided to make the shuffle deterministic.
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
///
/// For example, a clump size of 16 means we shuffle blocks of 16 contiguous rows. This
/// results in 16x fewer IOPS, but those 16 rows will always end up close together, which can
/// influence the performance of the model. Note: shuffling within clumps can still be done at
/// read time, but this only provides a local shuffle, not a global shuffle.
Random {
seed: Option<u64>,
clump_size: Option<u64>,
},
/// The data is not shuffled
///
/// This is useful for debugging and testing.
None,
}
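// Illustrative examples (values are arbitrary):
//
//     // Fully random, row-level shuffle with a fixed seed:
//     let row_level = ShuffleStrategy::Random { seed: Some(42), clump_size: None };
//     // Shuffle 16-row clumps to reduce IOPS when reading from object storage:
//     let clumped = ShuffleStrategy::Random { seed: None, clump_size: Some(16) };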
impl Default for ShuffleStrategy {
fn default() -> Self {
Self::None
}
}
/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
/// can be used to create a permutation reader that reads rows in the order defined by the permutation.
///
/// The permutation table is not a materialized copy of the underlying data and can be very
/// lightweight: it is a separate table that stores only the row id and split id for each row.
pub struct PermutationBuilder {
config: PermutationConfig,
base_table: Table,
}
impl PermutationBuilder {
pub fn new(base_table: Table) -> Self {
Self {
config: PermutationConfig::default(),
base_table,
}
}
/// Configures the strategy for assigning rows to splits.
///
/// For example, it is common to create a test/train split of the data. Splits can also be used
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
/// create a single split with 10% of the data.
///
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
/// multiple processes and multiple nodes.
///
/// The default is a single split that contains all rows.
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
self.config.split_strategy = split_strategy;
self
}
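// Illustrative sketch of the "use only 10% of the data" case mentioned above
// (`SplitSizes` comes from the split module; values are arbitrary):
//
//     let builder = builder.with_split_strategy(SplitStrategy::Random {
//         seed: None,
//         sizes: SplitSizes::Percentages(vec![0.1]),
//     });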
/// Configures the strategy for shuffling the data.
///
/// The default is `ShuffleStrategy::None`; the data is not shuffled unless this method is called.
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
self.config.shuffle_strategy = shuffle_strategy;
self
}
/// Configures a filter to apply to the base table.
///
/// Only rows matching the filter will be included in the permutation.
pub fn with_filter(mut self, filter: String) -> Self {
self.config.filter = Some(filter);
self
}
/// Configures the directory to use for temporary files.
///
/// The default is to use the operating system's default temporary directory.
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
self.config.temp_dir = temp_dir;
self
}
/// Stores the permutation as a table in a database
///
/// By default, the permutation is stored in memory. If this method is called then
/// the permutation will be stored as a table in the given database.
pub fn persist(mut self, database: Arc<dyn Database>, table_name: String) -> Self {
self.config.destination = PermutationDestination::Permanent(database, table_name);
self
}
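// Illustrative sketch: persist the permutation instead of keeping it in memory (assumes
// `conn` is the `Connection` that owns the base table; the table name is a placeholder):
//
//     let builder = PermutationBuilder::new(data_table.clone())
//         .persist(conn.database().clone(), "my_permutation".to_string());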
async fn sort_by_split_id(
&self,
data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
let ctx = SessionContext::new_with_config_rt(
SessionConfig::default(),
RuntimeEnvBuilder::new()
.with_memory_limit(100 * 1024 * 1024, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default()
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
)
.build_arc()
.unwrap(),
);
let df = ctx
.read_one_shot(data.into_df_stream())
.map_err(|e| Error::Other {
message: format!("Failed to setup sort by split id: {}", e),
source: Some(e.into()),
})?;
let df_stream = df
.sort_by(vec![col(SPLIT_ID_COLUMN)])
.map_err(|e| Error::Other {
message: format!("Failed to plan sort by split id: {}", e),
source: Some(e.into()),
})?
.execute_stream()
.await
.map_err(|e| Error::Other {
message: format!("Failed to sort by split id: {}", e),
source: Some(e.into()),
})?;
let schema = df_stream.schema();
let stream = df_stream.map_err(|e| Error::Other {
message: format!("Failed to execute sort by split id: {}", e),
source: Some(e.into()),
});
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
/// Builds the permutation table and stores it in the given database.
pub async fn build(self) -> Result<Table> {
// First pass, apply filter and load row ids
let mut rows = self.base_table.query().with_row_id();
if let Some(filter) = &self.config.filter {
rows = rows.only_if(filter);
}
let splitter = Splitter::new(
self.config.temp_dir.clone(),
self.config.split_strategy.clone(),
);
let mut needs_sort = !splitter.orders_by_split_id();
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
// split id)
rows = splitter.project(rows);
let num_rows = self
.base_table
.count_rows(self.config.filter.clone())
.await? as u64;
// Apply splits
let rows = rows.execute().await?;
let split_data = splitter.apply(rows, num_rows).await?;
// Shuffle data if requested
let shuffled = match self.config.shuffle_strategy {
ShuffleStrategy::None => split_data,
ShuffleStrategy::Random { seed, clump_size } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed,
clump_size,
temp_dir: self.config.temp_dir.clone(),
max_rows_per_file: 10 * 1024 * 1024,
});
shuffler.shuffle(split_data, num_rows).await?
}
};
// We want the final permutation to be sorted by the split id. If we shuffled or if
// the split was not assigned sequentially then we need to sort the data.
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
let sorted = if needs_sort {
self.sort_by_split_id(shuffled).await?
} else {
shuffled
};
// Rename _rowid to row_id
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
let (name, database) = match &self.config.destination {
PermutationDestination::Permanent(database, table_name) => {
(table_name.as_str(), database.clone())
}
PermutationDestination::Temporary => {
let conn = connect("memory:///").execute().await?;
("permutation", conn.database().clone())
}
};
let create_table_request =
CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));
let table = database.create_table(create_table_request).await?;
Ok(Table::new(table, database))
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use lance_datagen::{BatchCount, RowCount};
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::permutation::split::SplitSizes};
use super::*;
#[tokio::test]
async fn test_permutation_builder() {
let temp_dir = tempfile::tempdir().unwrap();
let db = connect(temp_dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let initial_data = lance_datagen::gen_batch()
.col("some_value", lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
let data_table = db
.create_table_streaming("mytbl", initial_data)
.execute()
.await
.unwrap();
let permutation_table = PermutationBuilder::new(data_table.clone())
.with_filter("some_value > 57".to_string())
.with_split_strategy(SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
})
.build()
.await
.unwrap();
println!("permutation_table: {:?}", permutation_table);
// Potentially brittle seed-dependent values below
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 0".to_string()))
.await
.unwrap(),
47
);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 1".to_string()))
.await
.unwrap(),
283
);
}
}

View File

@@ -0,0 +1,384 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Row ID-based views for LanceDB tables
//!
//! This module provides functionality for creating views that are based on specific row IDs.
//! The `PermutationReader` allows you to create a virtual table view that contains only
//! the rows from a source table that correspond to row IDs stored in a separate table.
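//!
//! Illustrative sketch (assumes `base` and `perm` are `Arc<dyn BaseTable>` handles for the
//! base table and the permutation table; error handling elided):
//!
//!     let reader = PermutationReader::try_new(base, perm).await?;
//!     let mut split0 = reader
//!         .read_split(0, Select::All, QueryExecutionOptions::default())
//!         .await?;
//!     while let Some(batch) = split0.try_next().await? {
//!         // Batches arrive in the order defined by the permutation table.
//!     }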
use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
use crate::error::Error;
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
use crate::table::{AnyQuery, BaseTable};
use crate::Result;
use arrow::array::AsArray;
use arrow::datatypes::UInt64Type;
use arrow_array::{RecordBatch, UInt64Array};
use futures::{StreamExt, TryStreamExt};
use lance::arrow::RecordBatchExt;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::error::LanceOptionExt;
use lance_core::ROW_ID;
use std::collections::HashMap;
use std::sync::Arc;
/// Reads a permutation of a source table based on row IDs stored in a separate table
pub struct PermutationReader {
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
}
impl std::fmt::Debug for PermutationReader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PermutationReader(base={}, permutation={})",
self.base_table.name(),
self.permutation_table.name(),
)
}
}
impl PermutationReader {
/// Create a new PermutationReader
pub async fn try_new(
base_table: Arc<dyn BaseTable>,
permutation_table: Arc<dyn BaseTable>,
) -> Result<Self> {
let schema = permutation_table.schema().await?;
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named row_id".to_string(),
});
}
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
return Err(Error::InvalidInput {
message: "Permutation table must contain a column named split_id".to_string(),
});
}
Ok(Self {
base_table,
permutation_table,
})
}
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
for (expected, idx) in iter.enumerate() {
if *idx != expected as u64 {
return false;
}
}
true
}
async fn load_batch(
base_table: &Arc<dyn BaseTable>,
row_ids: RecordBatch,
selection: Select,
has_row_id: bool,
) -> Result<RecordBatch> {
let num_rows = row_ids.num_rows();
let row_ids = row_ids
.column(0)
.as_primitive_opt::<UInt64Type>()
.expect_ok()?
.values();
let filter = format!(
"_rowid in ({})",
row_ids
.iter()
.map(|o| o.to_string())
.collect::<Vec<_>>()
.join(",")
);
let base_query = QueryRequest {
filter: Some(QueryFilter::Sql(filter)),
select: selection,
with_row_id: true,
..Default::default()
};
let mut data = base_table
.query(
&AnyQuery::Query(base_query),
QueryExecutionOptions {
max_batch_length: num_rows as u32,
..Default::default()
},
)
.await?;
let Some(batch) = data.try_next().await? else {
return Err(Error::InvalidInput {
message: "Base table returned no batches".to_string(),
});
};
if data.try_next().await?.is_some() {
return Err(Error::InvalidInput {
message: "Base table returned more than one batch".to_string(),
});
}
if batch.num_rows() != num_rows {
return Err(Error::InvalidInput {
message: "Base table returned different number of rows than the number of row IDs"
.to_string(),
});
}
// There is no guarantee the result order will match the order provided
// so may need to restore order
let actual_row_ids = batch
.column_by_name(ROW_ID)
.expect_ok()?
.as_primitive_opt::<UInt64Type>()
.expect_ok()?
.values();
// Map from row id to order in batch, used to restore original ordering
let ordering = actual_row_ids
.iter()
.copied()
.enumerate()
.map(|(i, o)| (o, i as u64))
.collect::<HashMap<_, _>>();
let desired_idx_order = row_ids
.iter()
.map(|o| ordering.get(o).copied().expect_ok().map_err(Error::from))
.collect::<Result<Vec<_>>>()?;
let ordered_batch = if Self::is_sorted_already(desired_idx_order.iter()) {
// Fast path if already sorted, important as data may be large and
// re-ordering could be expensive
batch
} else {
let desired_idx_order = UInt64Array::from(desired_idx_order);
arrow_select::take::take_record_batch(&batch, &desired_idx_order)?
};
if has_row_id {
Ok(ordered_batch)
} else {
// The user didn't ask for row id, we needed it for ordering the data, but now we drop it
Ok(ordered_batch.drop_column(ROW_ID)?)
}
}
async fn row_ids_to_batches(
base_table: Arc<dyn BaseTable>,
row_ids: DatasetRecordBatchStream,
selection: Select,
) -> Result<SendableRecordBatchStream> {
let has_row_id = Self::has_row_id(&selection)?;
let mut stream = row_ids
.map_err(Error::from)
.try_filter_map(move |batch| {
let selection = selection.clone();
let base_table = base_table.clone();
async move {
Self::load_batch(&base_table, batch, selection, has_row_id)
.await
.map(Some)
}
})
.boxed();
// Need to read out first batch to get schema
let Some(first_batch) = stream.try_next().await? else {
return Err(Error::InvalidInput {
message: "Permutation was empty".to_string(),
});
};
let schema = first_batch.schema();
let stream = futures::stream::once(std::future::ready(Ok(first_batch))).chain(stream);
Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
}
fn has_row_id(selection: &Select) -> Result<bool> {
match selection {
Select::All => {
// _rowid is a system column and is not included in Select::All
Ok(false)
}
Select::Columns(columns) => Ok(columns.contains(&ROW_ID.to_string())),
Select::Dynamic(columns) => {
for column in columns {
if column.0 == ROW_ID {
if column.1 == ROW_ID {
return Ok(true);
} else {
return Err(Error::InvalidInput {
message: format!(
"Dynamic column {} cannot be used to select _rowid",
column.1
),
});
}
}
}
Ok(false)
}
}
}
pub async fn read_split(
&self,
split: u64,
selection: Select,
execution_options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let row_ids = self
.permutation_table
.query(
&AnyQuery::Query(QueryRequest {
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
..Default::default()
}),
execution_options,
)
.await?;
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use arrow_array::{ArrowPrimitiveType, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use lance_datagen::{BatchCount, RowCount};
use rand::seq::SliceRandom;
use crate::{
arrow::SendableRecordBatchStream,
query::{ExecutableQuery, QueryBase},
test_utils::datagen::{virtual_table, LanceDbDatagenExt},
Table,
};
use super::*;
async fn collect_from_stream<T: ArrowPrimitiveType>(
mut stream: SendableRecordBatchStream,
column: &str,
) -> Vec<T::Native> {
let mut row_ids = Vec::new();
while let Some(batch) = stream.try_next().await.unwrap() {
let col_idx = batch.schema().index_of(column).unwrap();
row_ids.extend(batch.column(col_idx).as_primitive::<T>().values().to_vec());
}
row_ids
}
async fn collect_column<T: ArrowPrimitiveType>(table: &Table, column: &str) -> Vec<T::Native> {
collect_from_stream::<T>(
table
.query()
.select(Select::Columns(vec![column.to_string()]))
.execute()
.await
.unwrap(),
column,
)
.await
}
#[tokio::test]
async fn test_permutation_reader() {
let base_table = lance_datagen::gen_batch()
.col("idx", lance_datagen::array::step::<Int32Type>())
.col("other_col", lance_datagen::array::step::<UInt64Type>())
.into_mem_table("tbl", RowCount::from(9), BatchCount::from(1))
.await;
let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
row_ids.shuffle(&mut rand::rng());
// Put the last two rows in split 1
let split_ids = UInt64Array::from_iter_values(
std::iter::repeat_n(0, row_ids.len() - 2).chain(std::iter::repeat_n(1, 2)),
);
let permutation_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("row_id", DataType::UInt64, false),
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
])),
vec![
Arc::new(UInt64Array::from(row_ids.clone())),
Arc::new(split_ids),
],
)
.unwrap();
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
let reader = PermutationReader::try_new(
base_table.base_table().clone(),
row_ids_table.base_table().clone(),
)
.await
.unwrap();
// Read split 0
let mut stream = reader
.read_split(
0,
Select::All,
QueryExecutionOptions {
max_batch_length: 3,
..Default::default()
},
)
.await
.unwrap();
assert_eq!(stream.schema(), base_table.schema().await.unwrap());
let check_batch = async |stream: &mut SendableRecordBatchStream,
expected_values: &[u64]| {
let batch = stream.try_next().await.unwrap().unwrap();
assert_eq!(batch.num_rows(), expected_values.len());
assert_eq!(
batch.column(0).as_primitive::<Int32Type>().values(),
&expected_values
.iter()
.map(|o| *o as i32)
.collect::<Vec<_>>()
);
assert_eq!(
batch.column(1).as_primitive::<UInt64Type>().values(),
&expected_values
);
};
check_batch(&mut stream, &row_ids[0..3]).await;
check_batch(&mut stream, &row_ids[3..6]).await;
check_batch(&mut stream, &row_ids[6..7]).await;
assert!(stream.try_next().await.unwrap().is_none());
// Read split 1
let mut stream = reader
.read_split(
1,
Select::All,
QueryExecutionOptions {
max_batch_length: 3,
..Default::default()
},
)
.await
.unwrap();
check_batch(&mut stream, &row_ids[7..9]).await;
assert!(stream.try_next().await.unwrap().is_none());
}
}

View File

@@ -0,0 +1,475 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use arrow::compute::concat_batches;
use arrow_array::{RecordBatch, UInt64Array};
use futures::{StreamExt, TryStreamExt};
use lance::io::ObjectStore;
use lance_core::{cache::LanceCache, utils::futures::FinallyStreamExt};
use lance_encoding::decoder::DecoderPlugins;
use lance_file::v2::{
reader::{FileReader, FileReaderOptions},
writer::{FileWriter, FileWriterOptions},
};
use lance_index::scalar::IndexReader;
use lance_io::{
scheduler::{ScanScheduler, SchedulerConfig},
utils::CachedFileSize,
};
use rand::{seq::SliceRandom, Rng, RngCore};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
Error, Result,
};
#[derive(Debug, Clone)]
pub struct ShufflerConfig {
/// An optional seed to make the shuffle deterministic
pub seed: Option<u64>,
/// The maximum number of rows to write to a single file
///
/// The shuffler will need to hold at least this many rows in memory. Setting this value
/// extremely large could cause the shuffler to use a lot of memory (depending on row size).
///
/// However, the shuffler will also need to hold total_num_rows / max_rows_per_file file
/// writers in memory. Each of these will consume some amount of data for column write buffers.
/// So setting this value too small could _also_ cause the shuffler to use a lot of memory and
/// open file handles.
pub max_rows_per_file: u64,
/// The temporary directory to use for writing files
pub temp_dir: TemporaryDirectory,
/// The size of the clumps to shuffle within
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
pub clump_size: Option<u64>,
}
impl Default for ShufflerConfig {
fn default() -> Self {
Self {
max_rows_per_file: 1024 * 1024,
seed: Option::default(),
temp_dir: TemporaryDirectory::default(),
clump_size: None,
}
}
}
/// A shuffler that can shuffle a stream of record batches
///
/// To do this the stream is consumed and written to temporary files. A new stream is returned
/// which returns the shuffled data from the temporary files.
///
/// If there are fewer than max_rows_per_file rows in the input stream, then the shuffler will not
/// write any files and will instead perform an in-memory shuffle.
///
/// The number of rows in the input stream must be known in advance.
pub struct Shuffler {
config: ShufflerConfig,
id: String,
}
impl Shuffler {
pub fn new(config: ShufflerConfig) -> Self {
let id = uuid::Uuid::new_v4().to_string();
Self { config, id }
}
/// Shuffles a single batch of data in memory
fn shuffle_batch(
batch: &RecordBatch,
rng: &mut dyn RngCore,
clump_size: u64,
) -> Result<RecordBatch> {
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
let mut indices = (0..num_clumps).collect::<Vec<_>>();
indices.shuffle(rng);
let indices = if clump_size == 1 {
UInt64Array::from(indices)
} else {
UInt64Array::from_iter_values(indices.iter().flat_map(|&clump_index| {
if clump_index == num_clumps - 1 {
clump_index * clump_size..batch.num_rows() as u64
} else {
clump_index * clump_size..(clump_index + 1) * clump_size
}
}))
};
Ok(arrow::compute::take_record_batch(batch, &indices)?)
}
async fn in_memory_shuffle(
&self,
data: SendableRecordBatchStream,
mut rng: Box<dyn RngCore + Send>,
) -> Result<SendableRecordBatchStream> {
let schema = data.schema();
let batches = data.try_collect::<Vec<_>>().await?;
let batch = concat_batches(&schema, &batches)?;
let shuffled = Self::shuffle_batch(&batch, &mut rng, self.config.clump_size.unwrap_or(1))?;
log::debug!("Shuffle job {}: in-memory shuffle complete", self.id);
Ok(Box::pin(SimpleRecordBatchStream::new(
futures::stream::once(async move { Ok(shuffled) }),
schema,
)))
}
async fn do_shuffle(
&self,
mut data: SendableRecordBatchStream,
num_rows: u64,
mut rng: Box<dyn RngCore + Send>,
) -> Result<SendableRecordBatchStream> {
let num_files = num_rows.div_ceil(self.config.max_rows_per_file);
let temp_dir = self.config.temp_dir.create_temp_dir()?;
let tmp_dir = temp_dir.path().to_path_buf();
let clump_size = self.config.clump_size.unwrap_or(1);
if clump_size == 0 {
return Err(Error::InvalidInput {
message: "clump size must be greater than 0".to_string(),
});
}
let object_store = ObjectStore::local();
let arrow_schema = data.schema();
let schema = lance::datatypes::Schema::try_from(arrow_schema.as_ref())?;
// Create file writers
let mut file_writers = Vec::with_capacity(num_files as usize);
for file_index in 0..num_files {
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", self.id));
let path =
object_store::path::Path::from_absolute_path(path).map_err(|err| Error::Other {
message: format!("Failed to create temporary file: {}", err),
source: None,
})?;
let object_writer = object_store.create(&path).await?;
let writer =
FileWriter::try_new(object_writer, schema.clone(), FileWriterOptions::default())?;
file_writers.push(writer);
}
let mut num_rows_seen = 0;
// Randomly distribute clumps to files
while let Some(batch) = data.try_next().await? {
num_rows_seen += batch.num_rows() as u64;
let is_last = num_rows_seen == num_rows;
if num_rows_seen > num_rows {
return Err(Error::Runtime {
message: format!("Expected {} rows but saw {} rows", num_rows, num_rows_seen),
});
}
// This is kind of an annoying limitation but if we allow runt clumps from batches then
// clumps will get unaligned and we will mess up the clumps when we do the in-memory
// shuffle step. If this is a problem we can probably figure out a better way to do this.
if !is_last && batch.num_rows() as u64 % clump_size != 0 {
return Err(Error::Runtime {
message: format!(
"Expected batch size ({}) to be divisible by clump size ({})",
batch.num_rows(),
clump_size
),
});
}
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
let mut batch_offsets_for_files =
vec![Vec::<u64>::with_capacity(batch.num_rows()); num_files as usize];
// Partition the batch randomly and write to the appropriate accumulator
for clump_offset in 0..num_clumps {
let clump_start = clump_offset * clump_size;
let num_rows_in_clump = clump_size.min(batch.num_rows() as u64 - clump_start);
let clump_end = clump_start + num_rows_in_clump;
let file_index = rng.random_range(0..num_files);
batch_offsets_for_files[file_index as usize].extend(clump_start..clump_end);
}
for (file_index, batch_offsets) in batch_offsets_for_files.into_iter().enumerate() {
if batch_offsets.is_empty() {
continue;
}
let indices = UInt64Array::from(batch_offsets);
let partition = arrow::compute::take_record_batch(&batch, &indices)?;
file_writers[file_index].write_batch(&partition).await?;
}
}
// Finish writing files
for (file_idx, mut writer) in file_writers.into_iter().enumerate() {
let num_written = writer.finish().await?;
log::debug!(
"Shuffle job {}: wrote {} rows to file {}",
self.id,
num_written,
file_idx
);
}
let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
let scan_scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config);
let job_id = self.id.clone();
let rng = Arc::new(Mutex::new(rng));
// Second pass, read each file as a single batch and shuffle
let stream = futures::stream::iter(0..num_files)
.then(move |file_index| {
let scan_scheduler = scan_scheduler.clone();
let rng = rng.clone();
let tmp_dir = tmp_dir.clone();
let job_id = job_id.clone();
async move {
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", job_id));
let path = object_store::path::Path::from_absolute_path(path).unwrap();
let file_scheduler = scan_scheduler
.open_file(&path, &CachedFileSize::unknown())
.await?;
let reader = FileReader::try_open(
file_scheduler,
None,
Arc::<DecoderPlugins>::default(),
&LanceCache::no_cache(),
FileReaderOptions::default(),
)
.await?;
// Need to read the entire file in a single batch for in-memory shuffling
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
let mut rng = rng.lock().unwrap();
Self::shuffle_batch(&batch, &mut rng, clump_size)
}
})
.finally(move || drop(temp_dir))
.boxed();
Ok(Box::pin(SimpleRecordBatchStream::new(stream, arrow_schema)))
}
pub async fn shuffle(
self,
data: SendableRecordBatchStream,
num_rows: u64,
) -> Result<SendableRecordBatchStream> {
log::debug!(
"Shuffle job {}: shuffling {} rows and {} columns",
self.id,
num_rows,
data.schema().fields.len()
);
let rng = non_crypto_rng(&self.config.seed);
if num_rows < self.config.max_rows_per_file {
return self.in_memory_shuffle(data, rng).await;
}
self.do_shuffle(data, num_rows, rng).await
}
}
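// A minimal usage sketch (not part of the original diff), assuming `input` is a
// SendableRecordBatchStream whose total row count is known up front. With fewer rows than
// `max_rows_per_file` the shuffle happens entirely in memory; otherwise batches are spilled
// to temporary Lance files and re-read in shuffled order.
#[allow(dead_code)]
async fn example_shuffle(
    input: SendableRecordBatchStream,
    num_rows: u64,
) -> Result<SendableRecordBatchStream> {
    let shuffler = Shuffler::new(ShufflerConfig {
        seed: Some(42),       // deterministic shuffle
        clump_size: Some(32), // shuffle 32-row clumps to keep reads contiguous
        ..Default::default()  // OS temp dir, 1Mi rows per spill file
    });
    // Note: when a clump size is set, every input batch except the last must be an exact
    // multiple of it, otherwise `do_shuffle` rejects the stream.
    shuffler.shuffle(input, num_rows).await
}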
#[cfg(test)]
mod tests {
use crate::arrow::LanceDbDatagenExt;
use super::*;
use arrow::{array::AsArray, datatypes::Int32Type};
use datafusion::prelude::SessionContext;
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed};
use rand::{rngs::SmallRng, SeedableRng};
fn test_gen() -> BatchGeneratorBuilder {
lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col("id", lance_datagen::array::step::<Int32Type>())
.col(
"name",
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
)
}
fn create_test_batch(size: RowCount) -> RecordBatch {
test_gen().into_batch_rows(size).unwrap()
}
fn create_test_stream(
num_batches: BatchCount,
batch_size: RowCount,
) -> SendableRecordBatchStream {
test_gen().into_ldb_stream(batch_size, num_batches)
}
#[test]
fn test_shuffle_batch_deterministic() {
let batch = create_test_batch(RowCount::from(10));
let mut rng1 = SmallRng::seed_from_u64(42);
let mut rng2 = SmallRng::seed_from_u64(42);
let shuffled1 = Shuffler::shuffle_batch(&batch, &mut rng1, 1).unwrap();
let shuffled2 = Shuffler::shuffle_batch(&batch, &mut rng2, 1).unwrap();
// Same seed should produce same shuffle
assert_eq!(shuffled1, shuffled2);
}
#[test]
fn test_shuffle_with_clumps() {
let batch = create_test_batch(RowCount::from(10));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 3).unwrap();
let values = shuffled.column(0).as_primitive::<Int32Type>();
let mut iter = values.into_iter().map(|o| o.unwrap());
let mut frag_seen = false;
let mut clumps_seen = 0;
while let Some(first) = iter.next() {
// 9 is the last value and not a full clump
if first != 9 {
// Otherwise we should have a full clump
let second = iter.next().unwrap();
let third = iter.next().unwrap();
assert_eq!(first + 1, second);
assert_eq!(first + 2, third);
clumps_seen += 1;
} else {
frag_seen = true;
}
}
assert_eq!(clumps_seen, 3);
assert!(frag_seen);
}
async fn sort_batch(batch: RecordBatch) -> RecordBatch {
let ctx = SessionContext::new();
let df = ctx.read_batch(batch).unwrap();
let sorted = df.sort_by(vec![col("id")]).unwrap();
let batches = sorted.collect().await.unwrap();
let schema = batches[0].schema();
concat_batches(&schema, &batches).unwrap()
}
#[tokio::test]
async fn test_shuffle_batch_preserves_data() {
let batch = create_test_batch(RowCount::from(100));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
assert_ne!(shuffled, batch);
let sorted = sort_batch(shuffled).await;
assert_eq!(sorted, batch);
}
#[test]
fn test_shuffle_batch_empty() {
let batch = create_test_batch(RowCount::from(0));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
assert_eq!(shuffled.num_rows(), 0);
}
#[tokio::test]
async fn test_in_memory_shuffle() {
let config = ShufflerConfig {
temp_dir: TemporaryDirectory::None,
..Default::default()
};
let shuffler = Shuffler::new(config);
let stream = create_test_stream(BatchCount::from(5), RowCount::from(20));
let result_stream = shuffler.shuffle(stream, 100).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
assert_eq!(result_batches.len(), 1);
let result_batch = result_batches.into_iter().next().unwrap();
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(20))
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = unshuffled_batches[0].schema();
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
let sorted = sort_batch(result_batch).await;
assert_eq!(unshuffled_batch, sorted);
}
#[tokio::test]
async fn test_external_shuffle() {
let config = ShufflerConfig {
max_rows_per_file: 100,
..Default::default()
};
let shuffler = Shuffler::new(config);
let stream = create_test_stream(BatchCount::from(5), RowCount::from(1000));
let result_stream = shuffler.shuffle(stream, 5000).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(1000))
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = unshuffled_batches[0].schema();
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
assert_eq!(result_batches.len(), 50);
let result_batch = concat_batches(&schema, &result_batches).unwrap();
let sorted = sort_batch(result_batch).await;
assert_eq!(unshuffled_batch, sorted);
}
#[test_log::test(tokio::test)]
async fn test_external_clump_shuffle() {
let config = ShufflerConfig {
max_rows_per_file: 100,
clump_size: Some(30),
..Default::default()
};
let shuffler = Shuffler::new(config);
// Batch size (900) must be a multiple of clump size (30)
let stream = create_test_stream(BatchCount::from(5), RowCount::from(900));
let schema = stream.schema();
// Remove 10 rows from the last batch to simulate ending with partial clump
let mut batches = stream.try_collect::<Vec<_>>().await.unwrap();
let last_index = batches.len() - 1;
let sliced_last = batches[last_index].slice(0, 890);
batches[last_index] = sliced_last;
let stream = Box::pin(SimpleRecordBatchStream::new(
futures::stream::iter(batches).map(Ok).boxed(),
schema.clone(),
));
let result_stream = shuffler.shuffle(stream, 4490).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
let result_batch = concat_batches(&schema, &result_batches).unwrap();
let ids = result_batch.column(0).as_primitive::<Int32Type>();
let mut iter = ids.into_iter().map(|o| o.unwrap());
while let Some(first) = iter.next() {
let rows_left_in_clump = if first == 4470 { 19 } else { 29 };
let mut expected_next = first + 1;
for _ in 0..rows_left_in_clump {
let next = iter.next().unwrap();
assert_eq!(next, expected_next);
expected_next += 1;
}
}
}
}


@@ -0,0 +1,804 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
iter,
sync::{
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
Arc,
},
};
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance::arrow::SchemaExt;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
permutation::shuffle::{Shuffler, ShufflerConfig},
permutation::util::TemporaryDirectory,
},
query::{Query, QueryBase, Select},
Error, Result,
};
pub const SPLIT_ID_COLUMN: &str = "split_id";
/// Strategy for assigning rows to splits
#[derive(Debug, Clone)]
pub enum SplitStrategy {
/// All rows will have split id 0
NoSplit,
/// Rows will be randomly assigned to splits
///
/// A seed can be provided to make the assignment deterministic.
Random {
seed: Option<u64>,
sizes: SplitSizes,
},
/// Rows will be assigned to splits based on the values in the specified columns.
///
/// This will ensure rows are always assigned to the same split if the given columns do not change.
///
/// The `split_weights` are used to determine the approximate number of rows in each split. This
/// controls how we divide up the u64 hash space. However, it does not guarantee any particular division
/// of rows. For example, if all rows have identical hash values then all rows will be assigned to the same split
/// regardless of the weights.
///
/// The `discard_weight` controls what percentage of rows should be thrown away. For example, if you want your
/// first split to have ~5% of your rows and the second split to have ~10% of your rows then you would set
/// split_weights to [1, 2] and discard_weight to 17 (or you could set split_weights to [5, 10] and discard_weight
/// to 85). If you set discard_weight to 0 then all rows will be assigned to a split.
Hash {
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
},
/// Rows will be assigned to splits sequentially.
///
/// The first N1 rows are assigned to split 1, the next N2 rows are assigned to split 2, etc.
///
/// This is mainly useful for debugging and testing.
Sequential { sizes: SplitSizes },
/// Rows will be assigned to splits based on a calculation of one or more columns.
///
/// This is useful when the splits already exist in the base table.
///
/// The provided `calculation` should be an SQL statement that returns an integer value between
/// 0 and the number of splits - 1 (the number of splits is defined by the `splits` configuration).
///
/// If this strategy is used then the counts/percentages in the SplitSizes are ignored.
Calculated { calculation: String },
}
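// A minimal sketch (not part of the original diff) of the hash-split example described
// above: ~5% of rows land in split 0, ~10% in split 1, and the remaining ~85% are
// discarded. Rows with identical values in the chosen column always map to the same split.
// The column name is hypothetical.
#[allow(dead_code)]
fn example_hash_strategy() -> SplitStrategy {
    SplitStrategy::Hash {
        columns: vec!["user_id".to_string()], // hash input; must exist in the base table
        split_weights: vec![1, 2],            // 1 part vs. 2 parts of the 20-part hash space
        discard_weight: 17,                   // remaining 17 parts (~85%) are dropped
    }
}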
// The default is not to split the data
//
// All data will be assigned to a single split.
impl Default for SplitStrategy {
fn default() -> Self {
Self::NoSplit
}
}
impl SplitStrategy {
pub fn validate(&self, num_rows: u64) -> Result<()> {
match self {
Self::NoSplit => Ok(()),
Self::Random { sizes, .. } => sizes.validate(num_rows),
Self::Hash {
split_weights,
columns,
..
} => {
if columns.is_empty() {
return Err(Error::InvalidInput {
message: "Hash strategy requires at least one column".to_string(),
});
}
if split_weights.is_empty() {
return Err(Error::InvalidInput {
message: "Hash strategy requires at least one split weight".to_string(),
});
}
if split_weights.contains(&0) {
return Err(Error::InvalidInput {
message: "Split weights must be greater than 0".to_string(),
});
}
Ok(())
}
Self::Sequential { sizes } => sizes.validate(num_rows),
Self::Calculated { .. } => Ok(()),
}
}
}
pub struct Splitter {
temp_dir: TemporaryDirectory,
strategy: SplitStrategy,
}
impl Splitter {
pub fn new(temp_dir: TemporaryDirectory, strategy: SplitStrategy) -> Self {
Self { temp_dir, strategy }
}
fn sequential_split_id(
num_rows: u64,
split_sizes: &[u64],
split_index: &AtomicUsize,
counter_in_split: &AtomicU64,
exhausted: &AtomicBool,
) -> UInt64Array {
let mut split_ids = Vec::<u64>::with_capacity(num_rows as usize);
while split_ids.len() < num_rows as usize {
let split_id = split_index.load(Ordering::Relaxed);
let counter = counter_in_split.load(Ordering::Relaxed);
let split_size = split_sizes[split_id];
let remaining_in_split = split_size - counter;
let remaining_in_batch = num_rows - split_ids.len() as u64;
let mut done = false;
let rows_to_add = if remaining_in_batch < remaining_in_split {
counter_in_split.fetch_add(remaining_in_batch, Ordering::Relaxed);
remaining_in_batch
} else {
split_index.fetch_add(1, Ordering::Relaxed);
counter_in_split.store(0, Ordering::Relaxed);
if split_id == split_sizes.len() - 1 {
exhausted.store(true, Ordering::Relaxed);
done = true;
}
remaining_in_split
};
split_ids.extend(iter::repeat(split_id as u64).take(rows_to_add as usize));
if done {
// Quit early if we've run out of splits
break;
}
}
UInt64Array::from(split_ids)
}
async fn apply_sequential(
&self,
source: SendableRecordBatchStream,
num_rows: u64,
sizes: &SplitSizes,
) -> Result<SendableRecordBatchStream> {
let split_sizes = sizes.to_counts(num_rows);
let split_index = AtomicUsize::new(0);
let counter_in_split = AtomicU64::new(0);
let exhausted = AtomicBool::new(false);
let schema = source.schema();
let new_schema = Arc::new(schema.try_with_column(Field::new(
SPLIT_ID_COLUMN,
DataType::UInt64,
false,
))?);
let new_schema_clone = new_schema.clone();
let stream = source.filter_map(move |batch| {
let batch = match batch {
Ok(batch) => batch,
Err(e) => {
return std::future::ready(Some(Err(e)));
}
};
if exhausted.load(Ordering::Relaxed) {
return std::future::ready(None);
}
let split_ids = Self::sequential_split_id(
batch.num_rows() as u64,
&split_sizes,
&split_index,
&counter_in_split,
&exhausted,
);
let mut arrays = batch.columns().to_vec();
// This can happen if we exhaust all splits in the middle of a batch
if split_ids.len() < batch.num_rows() {
arrays = arrays
.iter()
.map(|arr| arr.slice(0, split_ids.len()))
.collect();
}
arrays.push(Arc::new(split_ids));
std::future::ready(Some(Ok(
RecordBatch::try_new(new_schema.clone(), arrays).unwrap()
)))
});
Ok(Box::pin(SimpleRecordBatchStream::new(
stream,
new_schema_clone,
)))
}
fn hash_split_id(batch: &RecordBatch, thresholds: &[u64], total_weight: u64) -> UInt64Array {
let arrays = batch
.columns()
.iter()
// Don't hash the last column which should always be the row id
.take(batch.columns().len() - 1)
.cloned()
.collect::<Vec<_>>();
let mut hashes = vec![0; batch.num_rows()];
let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
create_hashes(&arrays, &random_state, &mut hashes).unwrap();
// As an example, let's assume the weights are 1, 2. Our total weight is 3.
//
// Our thresholds are [1, 3]
// Our modulo output will be 0, 1, or 2.
//
// thresholds.binary_search(0) => Err(0) => 0
// thresholds.binary_search(1) => Ok(0) => 1
// thresholds.binary_search(2) => Err(1) => 1
let split_ids = hashes
.iter()
.map(|h| {
let h = h % total_weight;
let split_id = match thresholds.binary_search(&h) {
Ok(i) => (i + 1) as u64,
Err(i) => i as u64,
};
if split_id == thresholds.len() as u64 {
// If we're at the last threshold then we discard the row (indicated by setting
// the split_id to null)
None
} else {
Some(split_id)
}
})
.collect::<Vec<_>>();
UInt64Array::from(split_ids)
}
async fn apply_hash(
&self,
source: SendableRecordBatchStream,
weights: &[u64],
discard_weight: u64,
) -> Result<SendableRecordBatchStream> {
let row_id_index = source.schema().fields.len() - 1;
let new_schema = Arc::new(Schema::new(vec![
source.schema().field(row_id_index).clone(),
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
]));
let total_weight = weights.iter().sum::<u64>() + discard_weight;
// Thresholds are the cumulative sum of the weights
let mut offset = 0;
let thresholds = weights
.iter()
.map(|w| {
let value = offset + w;
offset = value;
value
})
.collect::<Vec<_>>();
let new_schema_clone = new_schema.clone();
let stream = source.map_ok(move |batch| {
let split_ids = Self::hash_split_id(&batch, &thresholds, total_weight);
if split_ids.null_count() > 0 {
let is_valid = split_ids.nulls().unwrap().inner();
let is_valid_mask = BooleanArray::new(is_valid.clone(), None);
let split_ids = arrow::compute::filter(&split_ids, &is_valid_mask).unwrap();
let row_ids = batch.column(row_id_index);
let row_ids = arrow::compute::filter(row_ids.as_ref(), &is_valid_mask).unwrap();
RecordBatch::try_new(new_schema.clone(), vec![row_ids, split_ids]).unwrap()
} else {
RecordBatch::try_new(
new_schema.clone(),
vec![batch.column(row_id_index).clone(), Arc::new(split_ids)],
)
.unwrap()
}
});
Ok(Box::pin(SimpleRecordBatchStream::new(
stream,
new_schema_clone,
)))
}
pub async fn apply(
&self,
source: SendableRecordBatchStream,
num_rows: u64,
) -> Result<SendableRecordBatchStream> {
self.strategy.validate(num_rows)?;
match &self.strategy {
// For consistency, even when not splitting, we still emit a split id column of all 0s
SplitStrategy::NoSplit => {
self.apply_sequential(source, num_rows, &SplitSizes::Counts(vec![num_rows]))
.await
}
SplitStrategy::Random { seed, sizes } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed: *seed,
// In this case we are only shuffling row ids so we can use a large max_rows_per_file
max_rows_per_file: 10 * 1024 * 1024,
temp_dir: self.temp_dir.clone(),
clump_size: None,
});
let shuffled = shuffler.shuffle(source, num_rows).await?;
self.apply_sequential(shuffled, num_rows, sizes).await
}
SplitStrategy::Sequential { sizes } => {
self.apply_sequential(source, num_rows, sizes).await
}
// Nothing to do, split is calculated in projection
SplitStrategy::Calculated { .. } => Ok(source),
SplitStrategy::Hash {
split_weights,
discard_weight,
..
} => {
self.apply_hash(source, split_weights, *discard_weight)
.await
}
}
}
pub fn project(&self, query: Query) -> Query {
match &self.strategy {
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
SPLIT_ID_COLUMN.to_string(),
calculation.clone(),
)])),
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
_ => query,
}
}
pub fn orders_by_split_id(&self) -> bool {
match &self.strategy {
SplitStrategy::Hash { .. } | SplitStrategy::Calculated { .. } => true,
SplitStrategy::NoSplit
| SplitStrategy::Sequential { .. }
// It may seem surprising, but for Random we shuffle first and then assign splits
// sequentially, so the result is already sorted by split id
| SplitStrategy::Random { .. } => false,
}
}
}
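// A minimal usage sketch (not part of the original diff), assuming `source` is a
// SendableRecordBatchStream with a known row count (for Hash splits its last column must
// be the row id). Sequential splitting needs no temp dir; Random may spill through the
// shuffler.
#[allow(dead_code)]
async fn example_split(
    source: SendableRecordBatchStream,
    num_rows: u64,
) -> Result<SendableRecordBatchStream> {
    let splitter = Splitter::new(
        TemporaryDirectory::OsDefault,
        SplitStrategy::Sequential {
            sizes: SplitSizes::Percentages(vec![0.8, 0.2]), // ~80/20 train/validation split
        },
    );
    // For NoSplit/Sequential/Random the returned stream is the input plus a `split_id` column.
    splitter.apply(source, num_rows).await
}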
/// Split configuration - either percentages or absolute counts
///
/// If the percentages do not sum to 1.0 (or the counts do not sum to the total number of rows)
/// the remaining rows will not be included in the permutation.
///
/// The default implementation assigns all rows to a single split.
#[derive(Debug, Clone)]
pub enum SplitSizes {
/// Percentage splits (must sum to <= 1.0)
///
/// The number of rows in each split is the nearest integer to the percentage multiplied by
/// the total number of rows.
Percentages(Vec<f64>),
/// Absolute row counts per split
///
/// If the dataset doesn't contain enough matching rows to fill all splits then an error
/// will be raised.
Counts(Vec<u64>),
/// Divides data into a fixed number of splits
///
/// Will divide the data evenly.
///
/// If the number of rows is not divisible by the number of splits then the rows per split
/// is rounded down.
Fixed(u64),
}
impl Default for SplitSizes {
fn default() -> Self {
Self::Percentages(vec![1.0])
}
}
impl SplitSizes {
pub fn validate(&self, num_rows: u64) -> Result<()> {
match self {
Self::Percentages(percentages) => {
for percentage in percentages {
if *percentage < 0.0 || *percentage > 1.0 {
return Err(Error::InvalidInput {
message: "Split percentages must be between 0.0 and 1.0".to_string(),
});
}
if percentage * (num_rows as f64) < 1.0 {
return Err(Error::InvalidInput {
message: format!(
"One of the splits has {}% of {} rows which rounds to 0 rows",
percentage * 100.0,
num_rows
),
});
}
}
if percentages.iter().sum::<f64>() > 1.0 {
return Err(Error::InvalidInput {
message: "Split percentages must sum to 1.0 or less".to_string(),
});
}
}
Self::Counts(counts) => {
if counts.iter().sum::<u64>() > num_rows {
return Err(Error::InvalidInput {
message: format!(
"Split counts specified {} rows but only {} are available",
counts.iter().sum::<u64>(),
num_rows
),
});
}
if counts.contains(&0) {
return Err(Error::InvalidInput {
message: "Split counts must be greater than 0".to_string(),
});
}
}
Self::Fixed(num_splits) => {
if *num_splits > num_rows {
return Err(Error::InvalidInput {
message: format!(
"Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
*num_splits, num_rows
),
});
}
if (num_rows / num_splits) == 0 {
return Err(Error::InvalidInput {
message: format!(
"Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
*num_splits, num_rows
),
});
}
}
}
Ok(())
}
pub fn to_counts(&self, num_rows: u64) -> Vec<u64> {
let sizes = match self {
Self::Percentages(percentages) => {
let mut percentage_sum = 0.0_f64;
let mut counts = percentages
.iter()
.map(|p| {
let count = (p * (num_rows as f64)).round() as u64;
percentage_sum += p;
count
})
.collect::<Vec<_>>();
let sum = counts.iter().sum::<u64>();
let is_basically_one =
(num_rows as f64 - percentage_sum * num_rows as f64).abs() < 0.5;
// If the sum of percentages is close to 1.0 then rounding errors can add up
// to more or less than num_rows
//
// Drop items from buckets until we have the correct number of rows
let mut excess = sum as i64 - num_rows as i64;
let mut drop_idx = 0;
while excess > 0 {
if counts[drop_idx] > 0 {
counts[drop_idx] -= 1;
excess -= 1;
}
drop_idx += 1;
if drop_idx == counts.len() {
drop_idx = 0;
}
}
// On the other hand, if the percentages sum to ~1.0 then we also shouldn't _lose_
// rows due to rounding errors
let mut add_idx = 0;
while is_basically_one && excess < 0 {
counts[add_idx] += 1;
add_idx += 1;
excess += 1;
if add_idx == counts.len() {
add_idx = 0;
}
}
counts
}
Self::Counts(counts) => counts.clone(),
Self::Fixed(num_splits) => {
let rows_per_split = num_rows / *num_splits;
vec![rows_per_split; *num_splits as usize]
}
};
assert!(sizes.iter().sum::<u64>() <= num_rows);
sizes
}
}
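// A small illustration (not part of the original diff) of the rounding correction above:
// when the percentages sum to ~1.0, rounding error neither drops nor adds rows.
#[allow(dead_code)]
fn example_percentage_rounding() {
    let sizes = SplitSizes::Percentages(vec![0.333, 0.333, 0.334]);
    // Naive rounding gives 33 + 33 + 33 = 99; the correction tops the first split back up
    // so the three splits cover all 100 rows.
    assert_eq!(sizes.to_counts(100), vec![34, 33, 33]);
}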
#[cfg(test)]
mod tests {
use crate::arrow::LanceDbDatagenExt;
use super::*;
use arrow::{
array::AsArray,
compute::concat_batches,
datatypes::{Int32Type, UInt64Type},
};
use arrow_array::Int32Array;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, ByteCount, RowCount, Seed};
use std::sync::Arc;
const ID_COLUMN: &str = "id";
#[test]
fn test_split_sizes_percentages_validation() {
// Valid percentages
let sizes = SplitSizes::Percentages(vec![0.7, 0.3]);
assert!(sizes.validate(100).is_ok());
// Sum > 1.0
let sizes = SplitSizes::Percentages(vec![0.7, 0.4]);
assert!(sizes.validate(100).is_err());
// Negative percentage
let sizes = SplitSizes::Percentages(vec![-0.1, 0.5]);
assert!(sizes.validate(100).is_err());
// Percentage > 1.0
let sizes = SplitSizes::Percentages(vec![1.5]);
assert!(sizes.validate(100).is_err());
// Percentage rounds to 0 rows
let sizes = SplitSizes::Percentages(vec![0.001]);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_counts_validation() {
// Valid counts
let sizes = SplitSizes::Counts(vec![30, 70]);
assert!(sizes.validate(100).is_ok());
// Sum > num_rows
let sizes = SplitSizes::Counts(vec![60, 50]);
assert!(sizes.validate(100).is_err());
// Counts are 0
let sizes = SplitSizes::Counts(vec![0, 100]);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_fixed_validation() {
// Valid fixed splits
let sizes = SplitSizes::Fixed(5);
assert!(sizes.validate(100).is_ok());
// More splits than rows
let sizes = SplitSizes::Fixed(150);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_to_sizes_percentages() {
let sizes = SplitSizes::Percentages(vec![0.3, 0.7]);
let result = sizes.to_counts(100);
assert_eq!(result, vec![30, 70]);
// Test rounding
let sizes = SplitSizes::Percentages(vec![0.3, 0.41]);
let result = sizes.to_counts(70);
assert_eq!(result, vec![21, 29]);
}
#[test]
fn test_split_sizes_to_sizes_fixed() {
let sizes = SplitSizes::Fixed(4);
let result = sizes.to_counts(100);
assert_eq!(result, vec![25, 25, 25, 25]);
// Test with remainder
let sizes = SplitSizes::Fixed(3);
let result = sizes.to_counts(10);
assert_eq!(result, vec![3, 3, 3]);
}
fn test_data() -> SendableRecordBatchStream {
lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(10), BatchCount::from(5))
}
async fn verify_splitter(
splitter: Splitter,
data: SendableRecordBatchStream,
num_rows: u64,
expected_split_sizes: &[u64],
row_ids_in_order: bool,
) {
let split_batches = splitter
.apply(data, num_rows)
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = split_batches[0].schema();
let split_batch = concat_batches(&schema, &split_batches).unwrap();
let total_split_sizes = expected_split_sizes.iter().sum::<u64>();
assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
let mut expected = Vec::with_capacity(total_split_sizes as usize);
for (i, size) in expected_split_sizes.iter().enumerate() {
expected.extend(iter::repeat(i as u64).take(*size as usize));
}
let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;
assert_eq!(&expected, split_batch.column(1));
let expected_row_ids =
Arc::new(Int32Array::from_iter_values(0..total_split_sizes as i32)) as Arc<dyn Array>;
if row_ids_in_order {
assert_eq!(&expected_row_ids, split_batch.column(0));
} else {
assert_ne!(&expected_row_ids, split_batch.column(0));
}
}
#[tokio::test]
async fn test_fixed_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Fixed(3),
},
);
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], true).await;
}
#[tokio::test]
async fn test_fixed_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Fixed(3),
},
);
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], false).await;
}
#[tokio::test]
async fn test_counts_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Counts(vec![5, 15, 10]),
},
);
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], true).await;
}
#[tokio::test]
async fn test_counts_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Counts(vec![5, 15, 10]),
},
);
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], false).await;
}
#[tokio::test]
async fn test_percentages_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
},
);
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], true).await;
}
#[tokio::test]
async fn test_percentages_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
},
);
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], false).await;
}
#[tokio::test]
async fn test_hash_split() {
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col(
"hash1",
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
)
.col("hash2", lance_datagen::array::step::<Int32Type>())
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(10), BatchCount::from(5));
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Hash {
columns: vec!["hash1".to_string(), "hash2".to_string()],
split_weights: vec![1, 2],
discard_weight: 1,
},
);
let split_batches = splitter
.apply(data, 10)
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = split_batches[0].schema();
let split_batch = concat_batches(&schema, &split_batches).unwrap();
// These assertions are all based on the fixed seed in data generation but they match
// up roughly to what we expect (25% discarded, 25% in split 0, 50% in split 1)
// 14 rows (28%) are discarded because discard_weight is 1
assert_eq!(split_batch.num_rows(), 36);
assert_eq!(split_batch.num_columns(), 2);
let split_ids = split_batch.column(1).as_primitive::<UInt64Type>().values();
let num_in_split_0 = split_ids.iter().filter(|v| **v == 0).count();
let num_in_split_1 = split_ids.iter().filter(|v| **v == 1).count();
assert_eq!(num_in_split_0, 11); // 22%
assert_eq!(num_in_split_1, 25); // 50%
}
}


@@ -0,0 +1,98 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{path::PathBuf, sync::Arc};
use arrow_array::RecordBatch;
use arrow_schema::{Fields, Schema};
use datafusion_execution::disk_manager::DiskManagerMode;
use futures::TryStreamExt;
use rand::{rngs::SmallRng, RngCore, SeedableRng};
use tempfile::TempDir;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
Error, Result,
};
/// Directory to use for temporary files
#[derive(Debug, Clone, Default)]
pub enum TemporaryDirectory {
/// Use the operating system's default temporary directory (e.g. /tmp)
#[default]
OsDefault,
/// Use the specified directory (must be an absolute path)
Specific(PathBuf),
/// If spilling is required, then error out
None,
}
impl TemporaryDirectory {
pub fn create_temp_dir(&self) -> Result<TempDir> {
match self {
Self::OsDefault => tempfile::tempdir(),
Self::Specific(path) => tempfile::Builder::default().tempdir_in(path),
Self::None => {
return Err(Error::Runtime {
message: "No temporary directory was supplied and this operation requires spilling to disk".to_string(),
});
}
}
.map_err(|err| Error::Other {
message: "Failed to create temporary directory".to_string(),
source: Some(err.into()),
})
}
pub fn to_disk_manager_mode(&self) -> DiskManagerMode {
match self {
Self::OsDefault => DiskManagerMode::OsTmpDirectory,
Self::Specific(path) => DiskManagerMode::Directories(vec![path.clone()]),
Self::None => DiskManagerMode::Disabled,
}
}
}
pub fn non_crypto_rng(seed: &Option<u64>) -> Box<dyn RngCore + Send> {
Box::new(
seed.as_ref()
.map(|seed| SmallRng::seed_from_u64(*seed))
.unwrap_or_else(SmallRng::from_os_rng),
)
}
pub fn rename_column(
stream: SendableRecordBatchStream,
old_name: &str,
new_name: &str,
) -> Result<SendableRecordBatchStream> {
let schema = stream.schema();
let field_index = schema.index_of(old_name)?;
let new_fields = schema
.fields
.iter()
.cloned()
.enumerate()
.map(|(idx, f)| {
if idx == field_index {
Arc::new(f.as_ref().clone().with_name(new_name))
} else {
f
}
})
.collect::<Fields>();
let new_schema = Arc::new(Schema::new(new_fields).with_metadata(schema.metadata().clone()));
let new_schema_clone = new_schema.clone();
let renamed_stream = stream.and_then(move |batch| {
let renamed_batch =
RecordBatch::try_new(new_schema.clone(), batch.columns().to_vec()).map_err(Error::from);
std::future::ready(renamed_batch)
});
Ok(Box::pin(SimpleRecordBatchStream::new(
renamed_stream,
new_schema_clone,
)))
}


@@ -8,6 +8,7 @@ use std::sync::Arc;
use std::time::Duration;
use vector::IvfFlatIndexBuilder;
use crate::index::vector::IvfRqIndexBuilder;
use crate::{table::BaseTable, DistanceType, Error, Result};
use self::{
@@ -53,6 +54,9 @@ pub enum Index {
/// IVF index with Product Quantization
IvfPq(IvfPqIndexBuilder),
/// IVF index with RabitQ Quantization
IvfRq(IvfRqIndexBuilder),
/// IVF-HNSW index with Product Quantization
/// It is a variant of the HNSW algorithm that uses product quantization to compress the vectors.
IvfHnswPq(IvfHnswPqIndexBuilder),
@@ -275,6 +279,8 @@ pub enum IndexType {
IvfFlat,
#[serde(alias = "IVF_PQ")]
IvfPq,
#[serde(alias = "IVF_RQ")]
IvfRq,
#[serde(alias = "IVF_HNSW_PQ")]
IvfHnswPq,
#[serde(alias = "IVF_HNSW_SQ")]
@@ -296,6 +302,7 @@ impl std::fmt::Display for IndexType {
match self {
Self::IvfFlat => write!(f, "IVF_FLAT"),
Self::IvfPq => write!(f, "IVF_PQ"),
Self::IvfRq => write!(f, "IVF_RQ"),
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
Self::BTree => write!(f, "BTREE"),
@@ -317,6 +324,7 @@ impl std::str::FromStr for IndexType {
"FTS" | "INVERTED" => Ok(Self::FTS),
"IVF_FLAT" => Ok(Self::IvfFlat),
"IVF_PQ" => Ok(Self::IvfPq),
"IVF_RQ" => Ok(Self::IvfRq),
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
_ => Err(Error::InvalidInput {


@@ -291,6 +291,52 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
}
}
/// Builder for an IVF RQ index.
///
/// This index stores a compressed (quantized) copy of every vector. Each dimension
/// is quantized into a small number of bits.
/// The parameter `num_bits` controls this process, providing a tradeoff
/// between index size (and thus search speed) and index accuracy.
///
/// The partitioning process is called IVF and the `num_partitions` parameter controls how
/// many groups to create.
///
/// Note that training an IVF RQ index on a large dataset is a slow operation and
/// is currently also a memory-intensive operation.
#[derive(Debug, Clone)]
pub struct IvfRqIndexBuilder {
// IVF
pub(crate) distance_type: DistanceType,
pub(crate) num_partitions: Option<u32>,
pub(crate) num_bits: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
pub(crate) target_partition_size: Option<u32>,
}
impl Default for IvfRqIndexBuilder {
fn default() -> Self {
Self {
distance_type: DistanceType::L2,
num_partitions: None,
num_bits: None,
sample_rate: 256,
max_iterations: 50,
target_partition_size: None,
}
}
}
impl IvfRqIndexBuilder {
impl_distance_type_setter!();
impl_ivf_params_setter!();
pub fn num_bits(mut self, num_bits: u32) -> Self {
self.num_bits = Some(num_bits);
self
}
}
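// A minimal usage sketch (not part of the original diff), assuming `tbl` is an open table
// with a fixed-size-list vector column named "vector" and that the setters generated by the
// macros above (`distance_type`, `num_partitions`) are available as on the other builders.
#[allow(dead_code)]
async fn example_create_ivf_rq_index(tbl: &crate::table::Table) -> crate::error::Result<()> {
    tbl.create_index(
        &["vector"],
        crate::index::Index::IvfRq(
            IvfRqIndexBuilder::default()
                .distance_type(crate::DistanceType::Cosine)
                .num_partitions(256)
                .num_bits(1), // 1 bit per dimension; matches the fallback used at index creation
        ),
    )
    .execute()
    .await?;
    Ok(())
}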
/// Builder for an IVF HNSW PQ index.
///
/// This index is a combination of IVF and HNSW.

View File

@@ -194,6 +194,7 @@ pub mod arrow;
pub mod connection;
pub mod data;
pub mod database;
pub mod dataloader;
pub mod embeddings;
pub mod error;
pub mod index;
@@ -206,7 +207,8 @@ pub mod query;
pub mod remote;
pub mod rerankers;
pub mod table;
pub mod test_connection;
#[cfg(test)]
pub mod test_utils;
pub mod utils;
use std::fmt::Display;


@@ -6,10 +6,10 @@ use std::{future::Future, time::Duration};
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use arrow_schema::{DataType, SchemaRef};
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use futures::{stream, try_join, FutureExt, TryFutureExt, TryStreamExt};
use half::f16;
use lance::{
arrow::RecordBatchExt,
@@ -582,16 +582,40 @@ pub trait ExecutableQuery {
options: QueryExecutionOptions,
) -> impl Future<Output = Result<SendableRecordBatchStream>> + Send;
/// Explain the plan for a query
///
/// This will create a string representation of the plan that will be used to
/// execute the query. This will not execute the query.
///
/// This function can be used to get an understanding of what work will be done by the query
/// and is useful for debugging query performance.
fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
/// Execute the query and display the runtime metrics
///
/// This shows the same plan as [`ExecutableQuery::explain_plan`] but includes runtime metrics.
///
/// This function will actually execute the query in order to get the runtime metrics.
fn analyze_plan(&self) -> impl Future<Output = Result<String>> + Send {
self.analyze_plan_with_options(QueryExecutionOptions::default())
}
/// Execute the query and display the runtime metrics
///
/// This is the same as [`ExecutableQuery::analyze_plan`] but allows for specifying the execution options.
fn analyze_plan_with_options(
&self,
options: QueryExecutionOptions,
) -> impl Future<Output = Result<String>> + Send;
/// Return the output schema for data returned by the query without actually executing the query
///
/// This can be useful when the selection for a query is built dynamically as it is not always
/// obvious what the output schema will be.
fn output_schema(&self) -> impl Future<Output = Result<SchemaRef>> + Send {
self.create_plan(QueryExecutionOptions::default())
.and_then(|plan| std::future::ready(Ok(plan.schema())))
}
}
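// A minimal usage sketch (not part of the original diff), assuming `table` is an open
// Table. `output_schema` resolves the projected schema without executing anything, and
// `explain_plan` renders the physical plan as a string.
#[allow(dead_code)]
async fn example_inspect_query(table: &crate::table::Table) -> crate::error::Result<()> {
    let query = table
        .query()
        .select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")]));
    let schema = query.output_schema().await?;
    println!("projected schema: {:?}", schema);
    println!("{}", query.explain_plan(true).await?);
    Ok(())
}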
/// A query filter that can be applied to a query
@@ -1505,6 +1529,16 @@ mod tests {
.query()
.limit(10)
.select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")]));
let schema = query.output_schema().await.unwrap();
assert_eq!(
schema,
Arc::new(ArrowSchema::new(vec![
ArrowField::new("id2", DataType::Int32, true),
ArrowField::new("id", DataType::Int32, true),
]))
);
let result = query.execute().await;
let mut batches = result
.expect("should have result")


@@ -16,7 +16,7 @@ use tokio::task::spawn_blocking;
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
OpenTableRequest, TableNamesRequest,
OpenTableRequest, ReadConsistency, TableNamesRequest,
};
use crate::error::Result;
use crate::table::BaseTable;
@@ -189,6 +189,7 @@ struct ListTablesResponse {
pub struct RemoteDatabase<S: HttpSend = Sender> {
client: RestfulLanceDbClient<S>,
table_cache: Cache<String, Arc<RemoteTable<S>>>,
uri: String,
}
impl RemoteDatabase {
@@ -217,6 +218,7 @@ impl RemoteDatabase {
Ok(Self {
client,
table_cache,
uri: uri.to_owned(),
})
}
}
@@ -238,6 +240,7 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
}
}
@@ -250,6 +253,7 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
}
}
}
@@ -315,6 +319,17 @@ fn build_cache_key(name: &str, namespace: &[String]) -> String {
#[async_trait]
impl<S: HttpSend> Database for RemoteDatabase<S> {
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
Err(Error::NotSupported {
message: "Getting the read consistency of a remote database is not yet supported"
.to_string(),
})
}
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
let mut req = if !request.namespace.is_empty() {
let namespace_id =

View File

@@ -50,6 +50,7 @@ use std::sync::Arc;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::database::Database;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
@@ -510,6 +511,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// Get the namespace of the table.
fn namespace(&self) -> &[String];
/// Get the id of the table
///
/// This is the namespace of the table concatenated with the name
/// separated by a dot (".")
fn id(&self) -> &str;
/// Get the arrow [Schema] of the table.
async fn schema(&self) -> Result<SchemaRef>;
@@ -611,9 +615,10 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// A Table is a collection of strong typed Rows.
///
/// The type of the each row is defined in Apache Arrow [Schema].
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Table {
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -631,11 +636,13 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
handler,
handler.clone(),
None,
));
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -651,11 +658,13 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
handler,
handler.clone(),
Some(version),
));
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -670,9 +679,10 @@ impl std::fmt::Display for Table {
}
impl Table {
pub fn new(inner: Arc<dyn BaseTable>) -> Self {
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
Self {
inner,
database,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -681,12 +691,22 @@ impl Table {
&self.inner
}
pub fn database(&self) -> &Arc<dyn Database> {
&self.database
}
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
&self.embedding_registry
}
pub(crate) fn new_with_embedding_registry(
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
inner,
database,
embedding_registry,
}
}
@@ -1416,12 +1436,6 @@ impl Tags for NativeTags {
}
}
impl From<NativeTable> for Table {
fn from(table: NativeTable) -> Self {
Self::new(Arc::new(table))
}
}
pub trait NativeTableExt {
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;
@@ -1843,6 +1857,18 @@ impl NativeTable {
);
Ok(Box::new(lance_idx_params))
}
Index::IvfRq(index) => {
Self::validate_index_type(field, "IVF RQ", supported_vector_data_type)?;
let num_partitions = self
.get_num_partitions(index.num_partitions, false, None)
.await?;
let lance_idx_params = VectorIndexParams::ivf_rq(
num_partitions as usize,
index.num_bits.unwrap_or(1) as u8,
index.distance_type.into(),
);
Ok(Box::new(lance_idx_params))
}
Index::IvfHnswPq(index) => {
Self::validate_index_type(field, "IVF HNSW PQ", supported_vector_data_type)?;
let dim = Self::get_vector_dimension(field)?;
@@ -1912,9 +1938,11 @@ impl NativeTable {
Index::Bitmap(_) => IndexType::Bitmap,
Index::LabelList(_) => IndexType::LabelList,
Index::FTS(_) => IndexType::Inverted,
Index::IvfFlat(_) | Index::IvfPq(_) | Index::IvfHnswPq(_) | Index::IvfHnswSq(_) => {
IndexType::Vector
}
Index::IvfFlat(_)
| Index::IvfPq(_)
| Index::IvfRq(_)
| Index::IvfHnswPq(_)
| Index::IvfHnswSq(_) => IndexType::Vector,
}
}


@@ -125,6 +125,10 @@ impl ExecutionPlan for MetadataEraserExec {
fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
self.input.partition_statistics(partition)
}
fn supports_limit_pushdown(&self) -> bool {
true
}
}
#[derive(Debug)]


@@ -1,126 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Functions for testing connections.
#[cfg(test)]
pub mod test_utils {
use regex::Regex;
use std::env;
use std::io::{BufRead, BufReader};
use std::process::{Child, ChildStdout, Command, Stdio};
use crate::{connect, Connection};
use anyhow::{bail, Result};
use tempfile::{tempdir, TempDir};
pub struct TestConnection {
pub uri: String,
pub connection: Connection,
_temp_dir: Option<TempDir>,
_process: Option<TestProcess>,
}
struct TestProcess {
child: Child,
}
impl Drop for TestProcess {
#[allow(unused_must_use)]
fn drop(&mut self) {
self.child.kill();
}
}
pub async fn new_test_connection() -> Result<TestConnection> {
match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
Ok(script_path) => new_remote_connection(&script_path).await,
Err(_e) => new_local_connection().await,
}
}
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
let temp_dir = tempdir()?;
let data_path = temp_dir.path().to_str().unwrap().to_string();
let child_result = Command::new(script_path)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.arg(data_path.clone())
.spawn();
if child_result.is_err() {
bail!(format!(
"Unable to run {}: {:?}",
script_path,
child_result.err()
));
}
let mut process = TestProcess {
child: child_result.unwrap(),
};
let stdout = BufReader::new(process.child.stdout.take().unwrap());
let port = read_process_port(stdout)?;
let uri = "db://test";
let host_override = format!("http://localhost:{}", port);
let connection = create_new_connection(uri, &host_override).await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: Some(process),
})
}
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
let result = stdout.read_line(&mut line);
if let Err(err) = result {
bail!(format!(
"read_process_port: error while reading from process output: {}",
err
));
} else if result.unwrap() == 0 {
bail!("read_process_port: hit EOF before reading port from process output.");
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
return Ok(caps[1].to_string());
}
}
}
#[cfg(feature = "remote")]
async fn create_new_connection(
uri: &str,
host_override: &str,
) -> crate::error::Result<Connection> {
connect(uri)
.region("us-east-1")
.api_key("sk_localtest")
.host_override(host_override)
.execute()
.await
}
#[cfg(not(feature = "remote"))]
async fn create_new_connection(
_uri: &str,
_host_override: &str,
) -> crate::error::Result<Connection> {
panic!("remote feature not supported");
}
async fn new_local_connection() -> Result<TestConnection> {
let temp_dir = tempdir()?;
let uri = temp_dir.path().to_str().unwrap();
let connection = connect(uri).execute().await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: None,
})
}
}


@@ -0,0 +1,5 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod connection;
pub mod datagen;


@@ -0,0 +1,120 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Functions for testing connections.
use regex::Regex;
use std::env;
use std::io::{BufRead, BufReader};
use std::process::{Child, ChildStdout, Command, Stdio};
use crate::{connect, Connection};
use anyhow::{bail, Result};
use tempfile::{tempdir, TempDir};
pub struct TestConnection {
pub uri: String,
pub connection: Connection,
_temp_dir: Option<TempDir>,
_process: Option<TestProcess>,
}
struct TestProcess {
child: Child,
}
impl Drop for TestProcess {
#[allow(unused_must_use)]
fn drop(&mut self) {
self.child.kill();
}
}
pub async fn new_test_connection() -> Result<TestConnection> {
match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
Ok(script_path) => new_remote_connection(&script_path).await,
Err(_e) => new_local_connection().await,
}
}
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
let temp_dir = tempdir()?;
let data_path = temp_dir.path().to_str().unwrap().to_string();
let child_result = Command::new(script_path)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.arg(data_path.clone())
.spawn();
if child_result.is_err() {
bail!(format!(
"Unable to run {}: {:?}",
script_path,
child_result.err()
));
}
let mut process = TestProcess {
child: child_result.unwrap(),
};
let stdout = BufReader::new(process.child.stdout.take().unwrap());
let port = read_process_port(stdout)?;
let uri = "db://test";
let host_override = format!("http://localhost:{}", port);
let connection = create_new_connection(uri, &host_override).await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: Some(process),
})
}
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
let result = stdout.read_line(&mut line);
if let Err(err) = result {
bail!(format!(
"read_process_port: error while reading from process output: {}",
err
));
} else if result.unwrap() == 0 {
bail!("read_process_port: hit EOF before reading port from process output.");
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
return Ok(caps[1].to_string());
}
}
}
#[cfg(feature = "remote")]
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
connect(uri)
.region("us-east-1")
.api_key("sk_localtest")
.host_override(host_override)
.execute()
.await
}
#[cfg(not(feature = "remote"))]
async fn create_new_connection(
_uri: &str,
_host_override: &str,
) -> crate::error::Result<Connection> {
panic!("remote feature not supported");
}
async fn new_local_connection() -> Result<TestConnection> {
let temp_dir = tempdir()?;
let uri = temp_dir.path().to_str().unwrap();
let connection = connect(uri).execute().await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
_temp_dir: Some(temp_dir),
_process: None,
})
}


@@ -0,0 +1,55 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow_array::RecordBatch;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
connect, Error, Table,
};
#[async_trait::async_trait]
pub trait LanceDbDatagenExt {
async fn into_mem_table(
self,
table_name: &str,
rows_per_batch: RowCount,
num_batches: BatchCount,
) -> Table;
}
#[async_trait::async_trait]
impl LanceDbDatagenExt for BatchGeneratorBuilder {
async fn into_mem_table(
self,
table_name: &str,
rows_per_batch: RowCount,
num_batches: BatchCount,
) -> Table {
let (stream, schema) = self.into_reader_stream(rows_per_batch, num_batches);
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
stream.map_err(Error::from),
schema,
));
let db = connect("memory:///").execute().await.unwrap();
db.create_table_streaming(table_name, stream)
.execute()
.await
.unwrap()
}
}
pub async fn virtual_table(name: &str, values: &RecordBatch) -> Table {
let schema = values.schema();
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
futures::stream::once(std::future::ready(Ok(values.clone()))),
schema,
));
let db = connect("memory:///").execute().await.unwrap();
db.create_table_streaming(name, stream)
.execute()
.await
.unwrap()
}