Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-24 05:49:57 +00:00)

Compare commits: python-v0. ... python-v0. (40 commits)

Commits: 4f7b24d1a9, f9540724b7, aeac9c7644, 6ddd271627, f0d7520bdf, 7ef8bafd51, aed4a7c98e, 273ba18426, 8b94308cf2, 0b7b27481e, e1f9b011f8, d664b8739f, 20bec61ecb, 45255be42c, 93c2cf2f59, 9d29c83f81, 2a6143b5bd, b2242886e0, 199904ab35, 1fa888615f, 40967f3baa, 0bfc7de32c, d43880a585, 59a886958b, c36f6746d1, 25ce6d311f, 92a4e46f9f, 845641c480, d96404c635, 02d31ee412, 308623577d, 8ee3ae378f, 3372a2aae0, 4cfcd95320, a70ff04bc9, a9daa18be9, 3f2e3986e9, bf55feb9b6, 8f8e06a2da, 03eab0f091
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.22.2"
current_version = "0.22.3-beta.5"
parse = """(?x)
    (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.
.github/workflows/codex-update-lance-dependency.yml (vendored, new file, 107 lines)

@@ -0,0 +1,107 @@

name: Codex Update Lance Dependency

on:
  workflow_call:
    inputs:
      tag:
        description: "Tag name from Lance"
        required: true
        type: string
  workflow_dispatch:
    inputs:
      tag:
        description: "Tag name from Lance"
        required: true
        type: string

permissions:
  contents: write
  pull-requests: write
  actions: read

jobs:
  update:
    runs-on: ubuntu-latest
    steps:
      - name: Show inputs
        run: |
          echo "tag = ${{ inputs.tag }}"

      - name: Checkout Repo LanceDB
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: true

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: 20

      - name: Install Codex CLI
        run: npm install -g @openai/codex

      - name: Install Rust toolchain
        uses: dtolnay/rust-toolchain@stable
        with:
          toolchain: stable
          components: clippy, rustfmt

      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y protobuf-compiler libssl-dev

      - name: Install cargo-info
        run: cargo install cargo-info

      - name: Install Python dependencies
        run: python3 -m pip install --upgrade pip packaging

      - name: Configure git user
        run: |
          git config user.name "lancedb automation"
          git config user.email "robot@lancedb.com"

      - name: Configure Codex authentication
        env:
          CODEX_TOKEN_B64: ${{ secrets.CODEX_TOKEN }}
        run: |
          if [ -z "${CODEX_TOKEN_B64}" ]; then
            echo "Repository secret CODEX_TOKEN is not defined; skipping Codex execution."
            exit 1
          fi
          mkdir -p ~/.codex
          echo "${CODEX_TOKEN_B64}" | base64 --decode > ~/.codex/auth.json

      - name: Run Codex to update Lance dependency
        env:
          TAG: ${{ inputs.tag }}
          GITHUB_TOKEN: ${{ secrets.ROBOT_TOKEN }}
          GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
        run: |
          set -euo pipefail
          VERSION="${TAG#refs/tags/}"
          VERSION="${VERSION#v}"
          BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
          cat <<EOF >/tmp/codex-prompt.txt
          You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.

          Follow these steps exactly:
          1. Use the script "ci/set_lance_version.py" to update Lance dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
          2. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
          3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
          4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
          5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
          6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
          7. Push the branch to origin. If the branch already exists, force-push your changes.
          8. The "GH_TOKEN" environment variable is available; use the "gh" CLI for GitHub operations such as creating the pull request.
          9. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}).
          10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.

          Constraints:
          - Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.
          - Do not merge the PR.
          - If any command fails, diagnose and fix the issue instead of aborting.
          EOF
          codex --config shell_environment_policy.ignore_default_excludes=true exec --dangerously-bypass-approvals-and-sandbox "$(cat /tmp/codex-prompt.txt)"
AGENTS.md (new file, 101 lines)

@@ -0,0 +1,101 @@

LanceDB is a database designed for retrieval, including vector, full-text, and hybrid search.
It is a wrapper around Lance. There are two backends: local (in-process like SQLite) and
remote (against LanceDB Cloud).

The core of LanceDB is written in Rust. There are bindings in Python, Typescript, and Java.

Project layout:

* `rust/lancedb`: The LanceDB core Rust implementation.
* `python`: The Python bindings, using PyO3.
* `nodejs`: The Typescript bindings, using napi-rs
* `java`: The Java bindings

Common commands:

* Check for compiler errors: `cargo check --quiet --features remote --tests --examples`
* Run tests: `cargo test --quiet --features remote --tests`
* Run specific test: `cargo test --quiet --features remote -p <package_name> --test <test_name>`
* Lint: `cargo clippy --quiet --features remote --tests --examples`
* Format: `cargo fmt --all`

Before committing changes, run formatting.

## Coding tips

* When writing Rust doctests for things that require a connection or table reference,
  write them as a function instead of a fully executable test. This allows type checking
  to run but avoids needing a full test environment. For example:
  ```rust
  /// ```
  /// use lance_index::scalar::FullTextSearchQuery;
  /// use lancedb::query::{QueryBase, ExecutableQuery};
  ///
  /// # use lancedb::Table;
  /// # async fn query(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
  /// let results = table.query()
  ///     .full_text_search(FullTextSearchQuery::new("hello world".into()))
  ///     .execute()
  ///     .await?;
  /// # Ok(())
  /// # }
  /// ```
  ```

## Example plan: adding a new method on Table

Adding a new method involves first adding it to the Rust core, then exposing it
in the Python and TypeScript bindings. There are both local and remote tables.
Remote tables are implemented via a HTTP API and require the `remote` cargo
feature flag to be enabled. Python has both sync and async methods.

Rust core changes:

1. Add method on `Table` struct in `rust/lancedb/src/table.rs` (calls `BaseTable` trait).
2. Add method to `BaseTable` trait in `rust/lancedb/src/table.rs`.
3. Implement new trait method on `NativeTable` in `rust/lancedb/src/table.rs`.
   * Test with unit test in `rust/lancedb/src/table.rs`.
4. Implement new trait method on `RemoteTable` in `rust/lancedb/src/remote/table.rs`.
   * Test with unit test in `rust/lancedb/src/remote/table.rs` against mocked endpoint.

Python bindings changes:

1. Add PyO3 method binding in `python/src/table.rs`. Run `make develop` to compile bindings.
2. Add types for PyO3 method in `python/python/lancedb/_lancedb.pyi`.
3. Add method to `AsyncTable` class in `python/python/lancedb/table.py`.
4. Add abstract method to `Table` abstract base class in `python/python/lancedb/table.py`.
5. Add concrete sync method to `LanceTable` class in `python/python/lancedb/table.py`.
   * Should use `LOOP.run()` to call the corresponding `AsyncTable` method.
6. Add concrete sync method to `RemoteTable` class in `python/python/lancedb/remote/table.py`.
7. Add unit test in `python/tests/test_table.py`.

TypeScript bindings changes:

1. Add napi-rs method binding on `Table` in `nodejs/src/table.rs`.
2. Run `npm run build` to generate TypeScript definitions.
3. Add typescript method on abstract class `Table` in `nodejs/src/table.ts`.
4. Add concrete method on `LocalTable` class in `nodejs/src/native_table.ts`.
   * Note: despite the name, this class is also used for remote tables.
5. Add test in `nodejs/__test__/table.test.ts`.
6. Run `npm run docs` to generate TypeScript documentation.

## Review Guidelines

Please consider the following when reviewing code contributions.

### Rust API design

* Design public APIs so they can be evolved easily in the future without breaking
  changes. Often this means using builder patterns or options structs instead of
  long argument lists.
* For public APIs, prefer inputs that use `Into<T>` or `AsRef<T>` traits to allow
  more flexible inputs. For example, use `name: Into<String>` instead of `name: String`,
  so we don't have to write `func("my_string".to_string())`.

### Testing

* Ensure all new public APIs have documentation and examples.
* Ensure that all bugfixes and features have corresponding tests. **We do not merge
  code without tests.**

### Documentation

* New features must include updates to the rust documentation comments. Link to
  relevant structs and methods to increase the value of documentation.
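To make steps 3–4 of the TypeScript bindings plan in AGENTS.md concrete, here is a minimal sketch for a hypothetical `countDeletedRows()` method. The method name, the `./native` import path, and the class shapes are illustrative assumptions, not the actual lancedb code:

```ts
import { Table as NativeTable } from "./native"; // assumed path of the napi-rs generated binding (steps 1-2)

// nodejs/src/table.ts -- step 3: declare the new method on the abstract `Table` class.
export abstract class Table {
  /** Count rows that are deleted but not yet compacted away (hypothetical example). */
  abstract countDeletedRows(): Promise<number>;
}

// nodejs/src/native_table.ts -- step 4: implement it on `LocalTable` by delegating to the
// native binding. Despite the name, LocalTable also backs remote tables.
export class LocalTable extends Table {
  constructor(private readonly inner: NativeTable) {
    super();
  }

  async countDeletedRows(): Promise<number> {
    return this.inner.countDeletedRows(); // assumes the napi method added in step 1
  }
}
```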
CLAUDE.md (deleted, 80 lines)

@@ -1,80 +0,0 @@

The deleted root CLAUDE.md carried the same guidance now in AGENTS.md above (project overview, common commands, coding tips, and the example plan for adding a new method on Table), without the new "Review Guidelines" section.
Cargo.lock (generated; 142 lines changed)

Highlights of the generated lockfile churn:

* Every Lance crate (`fsst`, `lance`, `lance-arrow`, `lance-bitpacking`, `lance-core`, `lance-datafusion`, `lance-datagen`, `lance-encoding`, `lance-file`, `lance-index`, `lance-io`, `lance-linalg`, `lance-table`, `lance-testing`) moves from version 0.38.2 pinned to `git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2` to version 0.39.0 from the crates.io registry, with registry checksums recorded.
* `lance-namespace` is bumped from 0.0.18 to 0.39.0, and a new `lance-namespace-impls` 0.39.0 package is added (depending on arrow, arrow-ipc, arrow-schema, async-trait, bytes, lance, lance-core, lance-io, lance-namespace, object_store, reqwest, serde_json, thiserror, snafu, and url).
* Dependency lists shift to match the new releases; for example, `lance` and `lance-io` now pull in `lance-namespace`, and the standalone `fastbloom` package entry is removed.
* Workspace crate versions are bumped: `lancedb` 0.22.2 → 0.22.3-beta.5 (its dependency list now includes `async-openai` and `lance-namespace-impls`), `lancedb-nodejs` 0.22.2 → 0.22.3-beta.5, and `lancedb-python` 0.25.2 → 0.25.3-beta.5.
* `mock_instant` is bumped from 0.3.2 to 0.6.0 and no longer lists `once_cell`.
Cargo.toml (42 lines changed)

@@ -15,18 +15,20 @@ categories = ["database-implementations"]
rust-version = "1.78.0"

[workspace.dependencies]
lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"], "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-core = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datagen = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-file = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.38.2", default-features = false, "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-namespace = "0.0.18"
lance = { "version" = "=0.39.0", default-features = false }
lance-core = "=0.39.0"
lance-datagen = "=0.39.0"
lance-file = "=0.39.0"
lance-io = { "version" = "=0.39.0", default-features = false }
lance-index = "=0.39.0"
lance-linalg = "=0.39.0"
lance-namespace = "=0.39.0"
lance-namespace-impls = { "version" = "=0.39.0", "features" = ["dir-aws", "dir-gcp", "dir-azure", "dir-oss", "rest"] }
lance-table = "=0.39.0"
lance-testing = "=0.39.0"
lance-datafusion = "=0.39.0"
lance-encoding = "=0.39.0"
lance-arrow = "=0.39.0"
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "56.2", optional = false }

@@ -35,6 +37,7 @@ arrow-data = "56.2"
arrow-ipc = "56.2"
arrow-ord = "56.2"
arrow-schema = "56.2"
arrow-select = "56.2"
arrow-cast = "56.2"
async-trait = "0"
datafusion = { version = "50.1", default-features = false }

@@ -59,19 +62,4 @@ num-traits = "0.2"
regex = "1.10"
lazy_static = "1"
semver = "1.0.25"
crunchy = "0.2.4"
chrono = "0.4"
# Workaround for: https://github.com/Lokathor/bytemuck/issues/306
bytemuck_derive = ">=1.8.1, <1.9.0"

# This is only needed when we reference preview releases of lance
# Force to use the same lance version as the rest of the project to avoid duplicate dependencies
[patch.crates-io]
lance = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
ci/set_lance_version.py

@@ -55,7 +55,7 @@ def extract_features(line: str) -> list:
    match = re.search(r'"features"\s*=\s*\[\s*(.*?)\s*\]', line, re.DOTALL)
    if match:
        features_str = match.group(1)
        return [f.strip('"') for f in features_str.split(",") if len(f) > 0]
        return [f.strip().strip('"') for f in features_str.split(",") if f.strip()]
    return []

@@ -117,7 +117,7 @@ def update_cargo_toml(line_updater):
    lance_line = ""
    is_parsing_lance_line = False
    for line in lines:
        if line.startswith("lance") and not line.startswith("lance-namespace"):
        if line.startswith("lance"):
            # Check if this is a single-line or multi-line entry
            # Single-line entries either:
            # 1. End with } (complete inline table)

@@ -183,10 +183,8 @@ def set_preview_version(version: str):

    def line_updater(line: str) -> str:
        package_name = line.split("=", maxsplit=1)[0].strip()
        base_version = version.split("-")[0]  # Get the base version without beta suffix

        # Build config in desired order: version, default-features, features, tag, git
        config = {"version": f"={base_version}"}
        config = {"version": f"={version}"}

        if extract_default_features(line):
            config["default-features"] = False
@@ -0,0 +1,97 @@

# VoyageAI Embeddings : Multimodal

VoyageAI embeddings can also be used to embed both text and image data. Only some of the models support image data; the supported models are listed at
[https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings).

Supported parameters (to be passed in the `create` method) are:

| Parameter | Type | Default Value | Description |
|---|---|-------------------------|-------------------------------------------|
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |

Usage Example:

```python
import base64
import os
from io import BytesIO

import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd

os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'

db = lancedb.connect(".lancedb")
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")


def image_to_base64(image_bytes: bytes):
    buffered = BytesIO(image_bytes)
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")


class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()  # image uri as the source
    image_bytes: str = func.SourceField()  # image bytes base64 encoded as the source
    vector: Vector(func.ndims()) = func.VectorField()  # vector column
    vec_from_bytes: Vector(func.ndims()) = func.VectorField()  # Another vector column


if "images" in db.table_names():
    db.drop_table("images")
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
    "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
    "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
table.add(
    pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
```

Now we can search using text, against both the default vector column and the custom vector column:

```python
# text search against the default vector column
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
print(actual.label)  # prints "dog"

# text search against the custom vector column
frombytes = (
    table.search("man's best friend", vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(frombytes.label)
```

Because we're using a multi-modal embedding function, we can also search using images:

```python
from PIL import Image  # needed to open the downloaded query image

# image search against the default vector column
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")

# image search using a custom vector column
other = (
    table.search(query_image, vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(other.label)
```
@@ -397,117 +397,6 @@ For **read-only access**, LanceDB will need a policy such as:
}

#### DynamoDB Commit Store for concurrent writes

By default, S3 does not support concurrent writes. Having two or more processes
writing to the same table at the same time can lead to data corruption. This is
because S3, unlike other object stores, does not have any atomic put or copy
operation.

To enable concurrent writes, you can configure LanceDB to use a DynamoDB table
as a commit store. This table will be used to coordinate writes between
different processes. To enable this feature, you must modify your connection
URI to use the `s3+ddb` scheme and add a query parameter `ddbTableName` with the
name of the table to use.

=== "Python"

    === "Sync API"

        ```python
        import lancedb
        db = lancedb.connect(
            "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
        )
        ```
    === "Async API"

        ```python
        import lancedb
        async_db = await lancedb.connect_async(
            "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
        )
        ```

=== "JavaScript"

    ```javascript
    const lancedb = require("lancedb");

    const db = await lancedb.connect(
        "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
    );
    ```

The DynamoDB table must be created with the following schema:

- Hash key: `base_uri` (string)
- Range key: `version` (number)

You can create this programmatically with:

=== "Python"

    <!-- skip-test -->
    ```python
    import boto3

    dynamodb = boto3.client("dynamodb")
    table = dynamodb.create_table(
        TableName=table_name,
        KeySchema=[
            {"AttributeName": "base_uri", "KeyType": "HASH"},
            {"AttributeName": "version", "KeyType": "RANGE"},
        ],
        AttributeDefinitions=[
            {"AttributeName": "base_uri", "AttributeType": "S"},
            {"AttributeName": "version", "AttributeType": "N"},
        ],
        ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
    )
    ```

=== "JavaScript"

    <!-- skip-test -->
    ```javascript
    import {
      CreateTableCommand,
      DynamoDBClient,
    } from "@aws-sdk/client-dynamodb";

    const dynamodb = new DynamoDBClient({
      region: CONFIG.awsRegion,
      credentials: {
        accessKeyId: CONFIG.awsAccessKeyId,
        secretAccessKey: CONFIG.awsSecretAccessKey,
      },
      endpoint: CONFIG.awsEndpoint,
    });
    const command = new CreateTableCommand({
      TableName: table_name,
      AttributeDefinitions: [
        {
          AttributeName: "base_uri",
          AttributeType: "S",
        },
        {
          AttributeName: "version",
          AttributeType: "N",
        },
      ],
      KeySchema: [
        { AttributeName: "base_uri", KeyType: "HASH" },
        { AttributeName: "version", KeyType: "RANGE" },
      ],
      ProvisionedThroughput: {
        ReadCapacityUnits: 1,
        WriteCapacityUnits: 1,
      },
    });
    await dynamodb.send(command);
    ```

#### S3-compatible stores
@@ -64,6 +64,36 @@ builder.filter("age > 18 AND status = 'active'");

***

### persist()

```ts
persist(connection, tableName): PermutationBuilder
```

Configure the permutation to be persisted.

#### Parameters

* **connection**: [`Connection`](Connection.md)
    The connection to persist the permutation to

* **tableName**: `string`
    The name of the table to create

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
builder.persist(connection, "permutation_table");
```

***

### shuffle()

@@ -98,15 +128,15 @@ builder.shuffle({ seed: 42, clumpSize: 10 });
### splitCalculated()

```ts
splitCalculated(calculation): PermutationBuilder
splitCalculated(options): PermutationBuilder
```

Configure calculated splits for the permutation.

#### Parameters

* **calculation**: `string`
    SQL expression for calculating splits
* **options**: [`SplitCalculatedOptions`](../interfaces/SplitCalculatedOptions.md)
    Configuration for calculated splitting

#### Returns
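A usage sketch of the new options-based signature; the SQL expression and split names below are illustrative assumptions, not values from the library:

```ts
// `builder` is a PermutationBuilder obtained from permutationBuilder(table).
// Derive three named splits from a per-row calculation.
const withSplits = builder.splitCalculated({
  calculation: "id % 3",                   // SQL expression evaluated per row to choose a split
  splitNames: ["train", "valid", "test"],  // optional labels for the computed split values
});
```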
@@ -80,7 +80,7 @@ AnalyzeExec verbose=true, metrics=[]
|
||||
### execute()
|
||||
|
||||
```ts
|
||||
protected execute(options?): RecordBatchIterator
|
||||
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
|
||||
```
|
||||
|
||||
Execute the query and return the results as an
|
||||
@@ -91,7 +91,7 @@ Execute the query and return the results as an
|
||||
|
||||
#### Returns
|
||||
|
||||
[`RecordBatchIterator`](RecordBatchIterator.md)
|
||||
`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
|
||||
|
||||
#### See
|
||||
|
||||
@@ -343,6 +343,29 @@ This is useful for pagination.
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
#### Inherited from
|
||||
|
||||
`StandardQueryBase.outputSchema`
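For illustration, a small sketch of checking the output columns before executing; the table variable and column names are assumptions:

```ts
// Inspect the schema that the query would produce, without running it.
// `tbl` is assumed to be an open lancedb Table with "id" and "vector" columns.
const query = tbl.query().select(["id", "vector"]).limit(10);
const schema = await query.outputSchema();
console.log(schema.fields.map((f) => f.name)); // e.g. ["id", "vector"]
```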
|
||||
|
||||
***
|
||||
|
||||
### select()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -81,7 +81,7 @@ AnalyzeExec verbose=true, metrics=[]
|
||||
### execute()
|
||||
|
||||
```ts
|
||||
protected execute(options?): RecordBatchIterator
|
||||
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
|
||||
```
|
||||
|
||||
Execute the query and return the results as an
|
||||
@@ -92,7 +92,7 @@ Execute the query and return the results as an
|
||||
|
||||
#### Returns
|
||||
|
||||
[`RecordBatchIterator`](RecordBatchIterator.md)
|
||||
`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
|
||||
|
||||
#### See
|
||||
|
||||
@@ -140,6 +140,25 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
***
|
||||
|
||||
### select()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / RecordBatchIterator
|
||||
|
||||
# Class: RecordBatchIterator
|
||||
|
||||
## Implements
|
||||
|
||||
- `AsyncIterator`<`RecordBatch`>
|
||||
|
||||
## Constructors
|
||||
|
||||
### new RecordBatchIterator()
|
||||
|
||||
```ts
|
||||
new RecordBatchIterator(promise?): RecordBatchIterator
|
||||
```
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **promise?**: `Promise`<`RecordBatchIterator`>
|
||||
|
||||
#### Returns
|
||||
|
||||
[`RecordBatchIterator`](RecordBatchIterator.md)
|
||||
|
||||
## Methods
|
||||
|
||||
### next()
|
||||
|
||||
```ts
|
||||
next(): Promise<IteratorResult<RecordBatch<any>, any>>
|
||||
```
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`IteratorResult`<`RecordBatch`<`any`>, `any`>>
|
||||
|
||||
#### Implementation of
|
||||
|
||||
`AsyncIterator.next`
|
||||
@@ -76,7 +76,7 @@ AnalyzeExec verbose=true, metrics=[]
|
||||
### execute()
|
||||
|
||||
```ts
|
||||
protected execute(options?): RecordBatchIterator
|
||||
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
|
||||
```
|
||||
|
||||
Execute the query and return the results as an
|
||||
@@ -87,7 +87,7 @@ Execute the query and return the results as an
|
||||
|
||||
#### Returns
|
||||
|
||||
[`RecordBatchIterator`](RecordBatchIterator.md)
|
||||
`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
|
||||
|
||||
#### See
|
||||
|
||||
@@ -143,6 +143,29 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
#### Inherited from
|
||||
|
||||
[`QueryBase`](QueryBase.md).[`outputSchema`](QueryBase.md#outputschema)
|
||||
|
||||
***
|
||||
|
||||
### select()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -221,7 +221,7 @@ also increase the latency of your query. The default value is 1.5*limit.
|
||||
### execute()
|
||||
|
||||
```ts
|
||||
protected execute(options?): RecordBatchIterator
|
||||
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
|
||||
```
|
||||
|
||||
Execute the query and return the results as an
|
||||
@@ -232,7 +232,7 @@ Execute the query and return the results as an
|
||||
|
||||
#### Returns
|
||||
|
||||
[`RecordBatchIterator`](RecordBatchIterator.md)
|
||||
`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
|
||||
|
||||
#### See
|
||||
|
||||
@@ -498,6 +498,29 @@ This is useful for pagination.
|
||||
|
||||
***
|
||||
|
||||
### outputSchema()
|
||||
|
||||
```ts
|
||||
outputSchema(): Promise<Schema<any>>
|
||||
```
|
||||
|
||||
Returns the schema of the output that will be returned by this query.
|
||||
|
||||
This can be used to inspect the types and names of the columns that will be
|
||||
returned by the query before executing it.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Schema`<`any`>>
|
||||
|
||||
An Arrow Schema describing the output columns.
|
||||
|
||||
#### Inherited from
|
||||
|
||||
`StandardQueryBase.outputSchema`
|
||||
|
||||
***
|
||||
|
||||
### postfilter()
|
||||
|
||||
```ts
|
||||
|
||||
19
docs/src/js/functions/RecordBatchIterator.md
Normal file
19
docs/src/js/functions/RecordBatchIterator.md
Normal file
@@ -0,0 +1,19 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / RecordBatchIterator
|
||||
|
||||
# Function: RecordBatchIterator()
|
||||
|
||||
```ts
|
||||
function RecordBatchIterator(promisedInner): AsyncGenerator<RecordBatch<any>, void, unknown>
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
* **promisedInner**: `Promise`<`RecordBatchIterator`>
|
||||
|
||||
## Returns
|
||||
|
||||
`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
|
||||
@@ -7,7 +7,7 @@
|
||||
# Function: permutationBuilder()
|
||||
|
||||
```ts
|
||||
function permutationBuilder(table, destTableName): PermutationBuilder
|
||||
function permutationBuilder(table): PermutationBuilder
|
||||
```
|
||||
|
||||
Create a permutation builder for the given table.
|
||||
@@ -17,9 +17,6 @@ Create a permutation builder for the given table.
|
||||
* **table**: [`Table`](../classes/Table.md)
|
||||
The source table to create a permutation from
|
||||
|
||||
* **destTableName**: `string`
|
||||
The name for the destination permutation table
|
||||
|
||||
## Returns
|
||||
|
||||
[`PermutationBuilder`](../classes/PermutationBuilder.md)
|
||||
|
||||
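Since `permutationBuilder` no longer takes a destination table name, persistence is configured through `persist()` instead; a sketch, where the connection, table, and names are assumptions:

```ts
// Build a shuffled permutation of `table` and persist it under a new name.
// `db` and `table` are assumed to be an open Connection and Table.
const permutation = await permutationBuilder(table)
  .shuffle({ seed: 42 })
  .persist(db, "my_permutation")
  .execute();
console.log(permutation.name, await permutation.countRows());
```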
@@ -32,7 +32,6 @@
|
||||
- [PhraseQuery](classes/PhraseQuery.md)
|
||||
- [Query](classes/Query.md)
|
||||
- [QueryBase](classes/QueryBase.md)
|
||||
- [RecordBatchIterator](classes/RecordBatchIterator.md)
|
||||
- [Session](classes/Session.md)
|
||||
- [StaticHeaderProvider](classes/StaticHeaderProvider.md)
|
||||
- [Table](classes/Table.md)
|
||||
@@ -78,6 +77,7 @@
|
||||
- [RemovalStats](interfaces/RemovalStats.md)
|
||||
- [RetryConfig](interfaces/RetryConfig.md)
|
||||
- [ShuffleOptions](interfaces/ShuffleOptions.md)
|
||||
- [SplitCalculatedOptions](interfaces/SplitCalculatedOptions.md)
|
||||
- [SplitHashOptions](interfaces/SplitHashOptions.md)
|
||||
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
|
||||
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
|
||||
@@ -105,6 +105,7 @@
|
||||
|
||||
## Functions
|
||||
|
||||
- [RecordBatchIterator](functions/RecordBatchIterator.md)
|
||||
- [connect](functions/connect.md)
|
||||
- [makeArrowTable](functions/makeArrowTable.md)
|
||||
- [packBits](functions/packBits.md)
|
||||
|
||||
docs/src/js/interfaces/IvfRqOptions.md (new file, 101 lines)

@@ -0,0 +1,101 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / IvfRqOptions

# Interface: IvfRqOptions

## Properties

### distanceType?

```ts
optional distanceType: "l2" | "cosine" | "dot";
```

Distance type to use to build the index.

Default value is "l2".

This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and during quantization.

The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.

The following distance types are available:

"l2" - Euclidean distance.
"cosine" - Cosine distance.
"dot" - Dot product.

***

### maxIterations?

```ts
optional maxIterations: number;
```

Max iterations to train IVF kmeans.

When training an IVF index we use kmeans to calculate the partitions. This parameter
controls how many iterations of kmeans to run.

The default value is 50.

***

### numBits?

```ts
optional numBits: number;
```

Number of bits per dimension for residual quantization.

This value controls how much each residual component is compressed. The more
bits, the more accurate the index will be but the slower search. Typical values
are small integers; the default is 1 bit per dimension.

***

### numPartitions?

```ts
optional numPartitions: number;
```

The number of IVF partitions to create.

This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.

If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.

***

### sampleRate?

```ts
optional sampleRate: number;
```

The number of vectors, per partition, to sample when training IVF kmeans.

When an IVF index is trained, we need to calculate partitions. These are groups
of vectors that are similar to each other. To do this we use an algorithm called kmeans.

Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
random sample of the data. This parameter controls the size of the sample. The total
number of vectors used to train the index is `sample_rate * num_partitions`.

Increasing this value might improve the quality of the index but in most cases the
default should be sufficient.

The default value is 256.
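To make the defaults above concrete, a sketch that derives `numPartitions` with the documented square-root heuristic and assembles an `IvfRqOptions` value; the import path and table variable are assumptions, and the actual index-creation entry point is not shown in this diff:

```ts
import type { IvfRqOptions } from "@lancedb/lancedb"; // assumed to be exported from the package root

// `tbl` is assumed to be an open Table.
const rowCount = await tbl.countRows();
const options: IvfRqOptions = {
  distanceType: "cosine",                                      // must match the distance type used at search time
  numPartitions: Math.max(1, Math.round(Math.sqrt(rowCount))), // documented default heuristic: sqrt(row count)
  numBits: 1,                                                  // documented default: 1 bit per dimension
  sampleRate: 256,                                             // documented default
};
console.log(options);
```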
docs/src/js/interfaces/SplitCalculatedOptions.md (new file, 23 lines)

@@ -0,0 +1,23 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / SplitCalculatedOptions

# Interface: SplitCalculatedOptions

## Properties

### calculation

```ts
calculation: string;
```

***

### splitNames?

```ts
optional splitNames: string[];
```

@@ -24,6 +24,14 @@ optional discardWeight: number;

***

### splitNames?

```ts
optional splitNames: string[];
```

***

### splitWeights

@@ -37,3 +37,11 @@ optional ratios: number[];

```ts
optional seed: number;
```

***

### splitNames?

```ts
optional splitNames: string[];
```
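The new optional `splitNames` field can label splits produced by the existing builders; a sketch assuming the seed/ratios hunk above belongs to `SplitRandomOptions` (the ratios and names are illustrative):

```ts
// `builder` is a PermutationBuilder obtained from permutationBuilder(table).
// Name the two random splits instead of relying on positional indices.
const labeled = builder.splitRandom({
  ratios: [0.8, 0.2],            // 80/20 split
  seed: 42,                      // deterministic assignment of rows to splits
  splitNames: ["train", "test"], // optional names for the two splits
});
```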
@@ -29,3 +29,11 @@ optional fixed: number;

```ts
optional ratios: number[];
```

***

### splitNames?

```ts
optional splitNames: string[];
```
@@ -51,8 +51,11 @@ pub enum Error {
    DatasetAlreadyExists { uri: String, location: Location },
    #[snafu(display("Table '{name}' already exists"))]
    TableAlreadyExists { name: String },
    #[snafu(display("Table '{name}' was not found"))]
    TableNotFound { name: String },
    #[snafu(display("Table '{name}' was not found: {source}"))]
    TableNotFound {
        name: String,
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    #[snafu(display("Invalid table name '{name}': {reason}"))]
    InvalidTableName { name: String, reason: String },
    #[snafu(display("Embedding function '{name}' was not found: {reason}, {location}"))]

@@ -191,7 +194,7 @@ impl From<lancedb::Error> for Error {
                message,
                location: std::panic::Location::caller().to_snafu_location(),
            },
            lancedb::Error::TableNotFound { name } => Self::TableNotFound { name },
            lancedb::Error::TableNotFound { name, source } => Self::TableNotFound { name, source },
            lancedb::Error::TableAlreadyExists { name } => Self::TableAlreadyExists { name },
            lancedb::Error::EmbeddingFunctionNotFound { name, reason } => {
                Self::EmbeddingFunctionNotFound {
@@ -8,7 +8,7 @@
    <parent>
        <groupId>com.lancedb</groupId>
        <artifactId>lancedb-parent</artifactId>
        <version>0.22.2-final.0</version>
        <version>0.22.3-beta.5</version>
        <relativePath>../pom.xml</relativePath>
    </parent>

@@ -8,7 +8,7 @@
    <parent>
        <groupId>com.lancedb</groupId>
        <artifactId>lancedb-parent</artifactId>
        <version>0.22.2-final.0</version>
        <version>0.22.3-beta.5</version>
        <relativePath>../pom.xml</relativePath>
    </parent>

@@ -6,7 +6,7 @@
    <groupId>com.lancedb</groupId>
    <artifactId>lancedb-parent</artifactId>
    <version>0.22.2-final.0</version>
    <version>0.22.3-beta.5</version>
    <packaging>pom</packaging>
    <name>${project.artifactId}</name>
    <description>LanceDB Java SDK Parent POM</description>
13
nodejs/AGENTS.md
Normal file
@@ -0,0 +1,13 @@
|
||||
These are the typescript bindings of LanceDB.
|
||||
The core Rust library is in the `../rust/lancedb` directory, the rust binding
|
||||
code is in the `src/` directory and the typescript bindings are in
|
||||
the `lancedb/` directory.
|
||||
|
||||
Whenever you change the Rust code, you will need to recompile: `npm run build`.
|
||||
|
||||
Common commands:
|
||||
* Build: `npm run build`
|
||||
* Lint: `npm run lint`
|
||||
* Fix lints: `npm run lint-fix`
|
||||
* Test: `npm test`
|
||||
* Run single test file: `npm test __test__/arrow.test.ts`
|
||||
@@ -1,13 +0,0 @@
|
||||
These are the typescript bindings of LanceDB.
|
||||
The core Rust library is in the `../rust/lancedb` directory, the rust binding
|
||||
code is in the `src/` directory and the typescript bindings are in
|
||||
the `lancedb/` directory.
|
||||
|
||||
Whenever you change the Rust code, you will need to recompile: `npm run build`.
|
||||
|
||||
Common commands:
|
||||
* Build: `npm run build`
|
||||
* Lint: `npm run lint`
|
||||
* Fix lints: `npm run lint-fix`
|
||||
* Test: `npm test`
|
||||
* Run single test file: `npm test __test__/arrow.test.ts`
|
||||
1
nodejs/CLAUDE.md
Symbolic link
@@ -0,0 +1 @@
|
||||
AGENTS.md
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.22.2"
|
||||
version = "0.22.3-beta.5"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
@@ -38,23 +38,22 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation builder", () => {
|
||||
const builder = permutationBuilder(table, "permutation_table");
|
||||
const builder = permutationBuilder(table);
|
||||
expect(builder).toBeDefined();
|
||||
});
|
||||
|
||||
test("should execute basic permutation", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table");
|
||||
const builder = permutationBuilder(table);
|
||||
const permutationTable = await builder.execute();
|
||||
|
||||
expect(permutationTable).toBeDefined();
|
||||
expect(permutationTable.name).toBe("permutation_table");
|
||||
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
});
|
||||
|
||||
test("should create permutation with random splits", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table").splitRandom({
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
ratios: [1.0],
|
||||
seed: 42,
|
||||
});
|
||||
@@ -65,7 +64,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with percentage splits", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table").splitRandom({
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
ratios: [0.3, 0.7],
|
||||
seed: 42,
|
||||
});
|
||||
@@ -84,7 +83,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with count splits", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table").splitRandom({
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
counts: [3, 7],
|
||||
seed: 42,
|
||||
});
|
||||
@@ -102,7 +101,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with hash splits", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table").splitHash({
|
||||
const builder = permutationBuilder(table).splitHash({
|
||||
columns: ["id"],
|
||||
splitWeights: [50, 50],
|
||||
discardWeight: 0,
|
||||
@@ -122,10 +121,9 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with sequential splits", async () => {
|
||||
const builder = permutationBuilder(
|
||||
table,
|
||||
"permutation_table",
|
||||
).splitSequential({ ratios: [0.5, 0.5] });
|
||||
const builder = permutationBuilder(table).splitSequential({
|
||||
ratios: [0.5, 0.5],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
@@ -140,10 +138,9 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with calculated splits", async () => {
|
||||
const builder = permutationBuilder(
|
||||
table,
|
||||
"permutation_table",
|
||||
).splitCalculated("id % 2");
|
||||
const builder = permutationBuilder(table).splitCalculated({
|
||||
calculation: "id % 2",
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
@@ -159,7 +156,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with shuffle", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table").shuffle({
|
||||
const builder = permutationBuilder(table).shuffle({
|
||||
seed: 42,
|
||||
});
|
||||
|
||||
@@ -169,7 +166,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with shuffle and clump size", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table").shuffle({
|
||||
const builder = permutationBuilder(table).shuffle({
|
||||
seed: 42,
|
||||
clumpSize: 2,
|
||||
});
|
||||
@@ -180,9 +177,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with filter", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table").filter(
|
||||
"value > 50",
|
||||
);
|
||||
const builder = permutationBuilder(table).filter("value > 50");
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
@@ -190,7 +185,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should chain multiple operations", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table")
|
||||
const builder = permutationBuilder(table)
|
||||
.filter("value <= 80")
|
||||
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
|
||||
.shuffle({ seed: 123 });
|
||||
@@ -209,7 +204,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should throw error for invalid split arguments", () => {
|
||||
const builder = permutationBuilder(table, "permutation_table");
|
||||
const builder = permutationBuilder(table);
|
||||
|
||||
// Test no arguments provided
|
||||
expect(() => builder.splitRandom({})).toThrow(
|
||||
@@ -223,7 +218,7 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should throw error when builder is consumed", async () => {
|
||||
const builder = permutationBuilder(table, "permutation_table");
|
||||
const builder = permutationBuilder(table);
|
||||
|
||||
// Execute once
|
||||
await builder.execute();
|
||||
@@ -231,4 +226,146 @@ describe("PermutationBuilder", () => {
|
||||
// Should throw error on second execution
|
||||
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
|
||||
});
|
||||
|
||||
test("should accept custom split names with random splits", async () => {
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
ratios: [0.3, 0.7],
|
||||
seed: 42,
|
||||
splitNames: ["train", "test"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric (0, 1, etc.)
|
||||
// The names are metadata that can be used by higher-level APIs
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should accept custom split names with hash splits", async () => {
|
||||
const builder = permutationBuilder(table).splitHash({
|
||||
columns: ["id"],
|
||||
splitWeights: [50, 50],
|
||||
discardWeight: 0,
|
||||
splitNames: ["set_a", "set_b"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should accept custom split names with sequential splits", async () => {
|
||||
const builder = permutationBuilder(table).splitSequential({
|
||||
ratios: [0.5, 0.5],
|
||||
splitNames: ["first", "second"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBe(5);
|
||||
expect(split1Count).toBe(5);
|
||||
});
|
||||
|
||||
test("should accept custom split names with calculated splits", async () => {
|
||||
const builder = permutationBuilder(table).splitCalculated({
|
||||
calculation: "id % 2",
|
||||
splitNames: ["even", "odd"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should persist permutation to a new table", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const builder = permutationBuilder(table)
|
||||
.splitRandom({
|
||||
ratios: [0.7, 0.3],
|
||||
seed: 42,
|
||||
splitNames: ["train", "validation"],
|
||||
})
|
||||
.persist(db, "my_permutation");
|
||||
|
||||
// Execute the builder which will persist the table
|
||||
const permutationTable = await builder.execute();
|
||||
|
||||
// Verify the persisted table exists and can be opened
|
||||
const persistedTable = await db.openTable("my_permutation");
|
||||
expect(persistedTable).toBeDefined();
|
||||
|
||||
// Verify the persisted table has the correct number of rows
|
||||
const rowCount = await persistedTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Verify splits exist (numeric split_id values)
|
||||
const split0Count = await persistedTable.countRows("split_id = 0");
|
||||
const split1Count = await persistedTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
|
||||
// Verify the table returned by execute is the same as the persisted one
|
||||
const executedRowCount = await permutationTable.countRows();
|
||||
expect(executedRowCount).toBe(10);
|
||||
});
|
||||
|
||||
test("should persist permutation with multiple operations", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const builder = permutationBuilder(table)
|
||||
.filter("value > 30")
|
||||
.splitRandom({ ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] })
|
||||
.shuffle({ seed: 456 })
|
||||
.persist(db, "filtered_permutation");
|
||||
|
||||
// Execute the builder
|
||||
const permutationTable = await builder.execute();
|
||||
|
||||
// Verify the persisted table
|
||||
const persistedTable = await db.openTable("filtered_permutation");
|
||||
const rowCount = await persistedTable.countRows();
|
||||
expect(rowCount).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100
|
||||
|
||||
// Verify splits exist (numeric split_id values)
|
||||
const split0Count = await persistedTable.countRows("split_id = 0");
|
||||
const split1Count = await persistedTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(7);
|
||||
|
||||
// Verify the executed table matches
|
||||
const executedRowCount = await permutationTable.countRows();
|
||||
expect(executedRowCount).toBe(7);
|
||||
});
|
||||
});
|
||||
|
||||
111
nodejs/__test__/query.test.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import * as tmp from "tmp";
|
||||
|
||||
import { type Table, connect } from "../lancedb";
|
||||
import {
|
||||
Field,
|
||||
FixedSizeList,
|
||||
Float32,
|
||||
Int64,
|
||||
Schema,
|
||||
Utf8,
|
||||
makeArrowTable,
|
||||
} from "../lancedb/arrow";
|
||||
import { Index } from "../lancedb/indices";
|
||||
|
||||
describe("Query outputSchema", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
let table: Table;
|
||||
|
||||
beforeEach(async () => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
const db = await connect(tmpDir.name);
|
||||
|
||||
// Create table with explicit schema to ensure proper types
|
||||
const schema = new Schema([
|
||||
new Field("a", new Int64(), true),
|
||||
new Field("text", new Utf8(), true),
|
||||
new Field(
|
||||
"vec",
|
||||
new FixedSizeList(2, new Field("item", new Float32())),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
|
||||
const data = makeArrowTable(
|
||||
[
|
||||
{ a: 1n, text: "foo", vec: [1, 2] },
|
||||
{ a: 2n, text: "bar", vec: [3, 4] },
|
||||
{ a: 3n, text: "baz", vec: [5, 6] },
|
||||
],
|
||||
{ schema },
|
||||
);
|
||||
table = await db.createTable("test", data);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
tmpDir.removeCallback();
|
||||
});
|
||||
|
||||
it("should return schema for plain query", async () => {
|
||||
const schema = await table.query().outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(3);
|
||||
expect(schema.fields.map((f) => f.name)).toEqual(["a", "text", "vec"]);
|
||||
expect(schema.fields[0].type.toString()).toBe("Int64");
|
||||
expect(schema.fields[1].type.toString()).toBe("Utf8");
|
||||
});
|
||||
|
||||
it("should return schema with dynamic projection", async () => {
|
||||
const schema = await table.query().select({ bl: "a * 2" }).outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(1);
|
||||
expect(schema.fields[0].name).toBe("bl");
|
||||
expect(schema.fields[0].type.toString()).toBe("Int64");
|
||||
});
|
||||
|
||||
it("should return schema for vector search with _distance column", async () => {
|
||||
const schema = await table
|
||||
.vectorSearch([1, 2])
|
||||
.select(["a"])
|
||||
.outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(2);
|
||||
expect(schema.fields.map((f) => f.name)).toEqual(["a", "_distance"]);
|
||||
expect(schema.fields[0].type.toString()).toBe("Int64");
|
||||
expect(schema.fields[1].type.toString()).toBe("Float32");
|
||||
});
|
||||
|
||||
it("should return schema for FTS search", async () => {
|
||||
await table.createIndex("text", { config: Index.fts() });
|
||||
|
||||
const schema = await table
|
||||
.search("foo", "fts")
|
||||
.select(["a"])
|
||||
.outputSchema();
|
||||
|
||||
// FTS search includes _score column in addition to selected columns
|
||||
expect(schema.fields.length).toBe(2);
|
||||
expect(schema.fields.map((f) => f.name)).toContain("a");
|
||||
expect(schema.fields.map((f) => f.name)).toContain("_score");
|
||||
const aField = schema.fields.find((f) => f.name === "a");
|
||||
expect(aField?.type.toString()).toBe("Int64");
|
||||
});
|
||||
|
||||
it("should return schema for take query", async () => {
|
||||
const schema = await table.takeOffsets([0]).select(["text"]).outputSchema();
|
||||
|
||||
expect(schema.fields.length).toBe(1);
|
||||
expect(schema.fields[0].name).toBe("text");
|
||||
expect(schema.fields[0].type.toString()).toBe("Utf8");
|
||||
});
|
||||
|
||||
it("should return full schema when no select is specified", async () => {
|
||||
const schema = await table.query().outputSchema();
|
||||
|
||||
// Should return all columns
|
||||
expect(schema.fields.length).toBe(3);
|
||||
});
|
||||
});
|
||||
@@ -43,6 +43,7 @@ export {
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
UpdateResult,
|
||||
SplitCalculatedOptions,
|
||||
SplitRandomOptions,
|
||||
SplitHashOptions,
|
||||
SplitSequentialOptions,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import { Connection, LocalConnection } from "./connection.js";
|
||||
import {
|
||||
PermutationBuilder as NativePermutationBuilder,
|
||||
Table as NativeTable,
|
||||
ShuffleOptions,
|
||||
SplitCalculatedOptions,
|
||||
SplitHashOptions,
|
||||
SplitRandomOptions,
|
||||
SplitSequentialOptions,
|
||||
@@ -29,6 +31,23 @@ export class PermutationBuilder {
this.inner = inner;
}

/**
* Configure the permutation to be persisted.
*
* @param connection - The connection to persist the permutation to
* @param tableName - The name of the table to create
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.persist(connection, "permutation_table");
* ```
*/
persist(connection: Connection, tableName: string): PermutationBuilder {
const localConnection = connection as LocalConnection;
const newInner = this.inner.persist(localConnection.inner, tableName);
return new PermutationBuilder(newInner);
}

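A sketch of the intended flow for the new `persist` step, combining it with a split and `execute()`, in the spirit of the persistence tests above; the database path and table names are placeholders:

```ts
import { connect, permutationBuilder } from "@lancedb/lancedb";

async function persistPermutation() {
  const db = await connect("/data/my-db");
  const events = await db.openTable("events");

  // Configure splits, ask for the result to be written as a named table,
  // then execute the builder once.
  const permutation = await permutationBuilder(events)
    .splitRandom({ ratios: [0.8, 0.2], seed: 7, splitNames: ["train", "val"] })
    .persist(db, "events_permutation")
    .execute();

  // The permutation is now also reachable through the connection.
  const reopened = await db.openTable("events_permutation");
  console.log(await reopened.countRows(), await permutation.countRows());
}
```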
/**
|
||||
* Configure random splits for the permutation.
|
||||
*
|
||||
@@ -95,15 +114,15 @@ export class PermutationBuilder {
/**
* Configure calculated splits for the permutation.
*
* @param calculation - SQL expression for calculating splits
* @param options - Configuration for calculated splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitCalculated("user_id % 3");
* ```
*/
splitCalculated(calculation: string): PermutationBuilder {
const newInner = this.inner.splitCalculated(calculation);
splitCalculated(options: SplitCalculatedOptions): PermutationBuilder {
const newInner = this.inner.splitCalculated(options);
return new PermutationBuilder(newInner);
}

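Note that the `@example` in the doc comment above still shows the old string argument; after this change the call site takes a `SplitCalculatedOptions` object. A small before/after sketch, assuming the `PermutationBuilder` class is importable from the package root and with illustrative column and split names:

```ts
import type { PermutationBuilder } from "@lancedb/lancedb";

// Before this change: builder.splitCalculated("id % 2");
// After this change: pass an options object, optionally with readable names.
function useCalculatedSplits(builder: PermutationBuilder): PermutationBuilder {
  return builder.splitCalculated({
    calculation: "id % 2",
    splitNames: ["even", "odd"],
  });
}
```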
@@ -161,7 +180,6 @@ export class PermutationBuilder {
|
||||
* Create a permutation builder for the given table.
|
||||
*
|
||||
* @param table - The source table to create a permutation from
|
||||
* @param destTableName - The name for the destination permutation table
|
||||
* @returns A PermutationBuilder instance
|
||||
* @example
|
||||
* ```ts
|
||||
@@ -172,17 +190,13 @@ export class PermutationBuilder {
|
||||
* const trainingTable = await builder.execute();
|
||||
* ```
|
||||
*/
|
||||
export function permutationBuilder(
|
||||
table: Table,
|
||||
destTableName: string,
|
||||
): PermutationBuilder {
|
||||
export function permutationBuilder(table: Table): PermutationBuilder {
|
||||
// Extract the inner native table from the TypeScript wrapper
|
||||
const localTable = table as LocalTable;
|
||||
// Access inner through type assertion since it's private
|
||||
const nativeBuilder = nativePermutationBuilder(
|
||||
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
|
||||
(localTable as any).inner,
|
||||
destTableName,
|
||||
);
|
||||
return new PermutationBuilder(nativeBuilder);
|
||||
}
|
||||
|
||||
@@ -20,35 +20,25 @@ import {
|
||||
} from "./native";
|
||||
import { Reranker } from "./rerankers";
|
||||
|
||||
export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
|
||||
private promisedInner?: Promise<NativeBatchIterator>;
|
||||
private inner?: NativeBatchIterator;
|
||||
export async function* RecordBatchIterator(
|
||||
promisedInner: Promise<NativeBatchIterator>,
|
||||
) {
|
||||
const inner = await promisedInner;
|
||||
|
||||
constructor(promise?: Promise<NativeBatchIterator>) {
|
||||
// TODO: check promise reliably so we dont need to pass two arguments.
|
||||
this.promisedInner = promise;
|
||||
if (inner === undefined) {
|
||||
throw new Error("Invalid iterator state");
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noExplicitAny: skip
|
||||
async next(): Promise<IteratorResult<RecordBatch<any>>> {
|
||||
if (this.inner === undefined) {
|
||||
this.inner = await this.promisedInner;
|
||||
}
|
||||
if (this.inner === undefined) {
|
||||
throw new Error("Invalid iterator state state");
|
||||
}
|
||||
const n = await this.inner.next();
|
||||
if (n == null) {
|
||||
return Promise.resolve({ done: true, value: null });
|
||||
}
|
||||
const tbl = tableFromIPC(n);
|
||||
if (tbl.batches.length != 1) {
|
||||
for (let buffer = await inner.next(); buffer; buffer = await inner.next()) {
|
||||
const { batches } = tableFromIPC(buffer);
|
||||
|
||||
if (batches.length !== 1) {
|
||||
throw new Error("Expected only one batch");
|
||||
}
|
||||
return Promise.resolve({ done: false, value: tbl.batches[0] });
|
||||
|
||||
yield batches[0];
|
||||
}
|
||||
}
|
||||
/* eslint-enable */
|
||||
|
||||
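The hand-rolled `AsyncIterator` class above is replaced by an async generator that awaits the native iterator once and yields one Arrow `RecordBatch` per IPC buffer. Consumption is unchanged; a minimal sketch (database path, table name, and filter are illustrative):

```ts
import { connect } from "@lancedb/lancedb";

async function streamBatches() {
  const db = await connect("/tmp/lancedb-demo");
  const table = await db.openTable("events");

  // QueryBase wires the generator into Symbol.asyncIterator, so results
  // stream batch by batch instead of materializing the whole result set.
  for await (const batch of table.query().where("value > 50")) {
    console.log(batch.numRows);
  }
}
```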
class RecordBatchIterable<
|
||||
NativeQueryType extends NativeQuery | NativeVectorQuery | NativeTakeQuery,
|
||||
@@ -64,7 +54,7 @@ class RecordBatchIterable<
|
||||
|
||||
// biome-ignore lint/suspicious/noExplicitAny: skip
|
||||
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
|
||||
return new RecordBatchIterator(
|
||||
return RecordBatchIterator(
|
||||
this.inner.execute(this.options?.maxBatchLength, this.options?.timeoutMs),
|
||||
);
|
||||
}
|
||||
@@ -231,10 +221,8 @@ export class QueryBase<
|
||||
* single query)
|
||||
*
|
||||
*/
|
||||
protected execute(
|
||||
options?: Partial<QueryExecutionOptions>,
|
||||
): RecordBatchIterator {
|
||||
return new RecordBatchIterator(this.nativeExecute(options));
|
||||
protected execute(options?: Partial<QueryExecutionOptions>) {
|
||||
return RecordBatchIterator(this.nativeExecute(options));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -242,8 +230,7 @@ export class QueryBase<
|
||||
*/
|
||||
// biome-ignore lint/suspicious/noExplicitAny: skip
|
||||
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
|
||||
const promise = this.nativeExecute();
|
||||
return new RecordBatchIterator(promise);
|
||||
return RecordBatchIterator(this.nativeExecute());
|
||||
}
|
||||
|
||||
/** Collect the results as an Arrow @see {@link ArrowTable}. */
|
||||
@@ -326,6 +313,25 @@
return this.inner.analyzePlan();
}
}

/**
* Returns the schema of the output that will be returned by this query.
*
* This can be used to inspect the types and names of the columns that will be
* returned by the query before executing it.
*
* @returns An Arrow Schema describing the output columns.
*/
async outputSchema(): Promise<import("./arrow").Schema> {
let schemaBuffer: Buffer;
if (this.inner instanceof Promise) {
schemaBuffer = await this.inner.then((inner) => inner.outputSchema());
} else {
schemaBuffer = await this.inner.outputSchema();
}
const schema = tableFromIPC(schemaBuffer).schema;
return schema;
}
}

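A usage sketch for the new `outputSchema()` helper, mirroring the cases in `query.test.ts`; the table and column names are illustrative:

```ts
import { connect } from "@lancedb/lancedb";

async function inspectProjection() {
  const db = await connect("/tmp/lancedb-demo");
  const table = await db.openTable("events");

  // Resolve the projected schema without executing the query.
  const schema = await table
    .query()
    .select({ doubled: "value * 2" })
    .outputSchema();

  for (const field of schema.fields) {
    console.log(`${field.name}: ${field.type}`);
  }
}
```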
export class StandardQueryBase<
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.22.2",
|
||||
"version": "0.22.3-beta.5",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use lancedb::database::CreateTableMode;
|
||||
use lancedb::database::{CreateTableMode, Database};
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::*;
|
||||
|
||||
@@ -41,6 +41,10 @@ impl Connection {
|
||||
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn database(&self) -> napi::Result<Arc<dyn Database>> {
|
||||
Ok(self.get_inner()?.database().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
|
||||
@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{error::NapiErrorExt, table::Table};
|
||||
use lancedb::dataloader::{
|
||||
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
split::{SplitSizes, SplitStrategy},
|
||||
permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
permutation::split::{SplitSizes, SplitStrategy},
|
||||
};
|
||||
use napi_derive::napi;
|
||||
|
||||
@@ -16,6 +16,7 @@ pub struct SplitRandomOptions {
|
||||
pub counts: Option<Vec<i64>>,
|
||||
pub fixed: Option<i64>,
|
||||
pub seed: Option<i64>,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
@@ -23,6 +24,7 @@ pub struct SplitHashOptions {
|
||||
pub columns: Vec<String>,
|
||||
pub split_weights: Vec<i64>,
|
||||
pub discard_weight: Option<i64>,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
@@ -30,6 +32,13 @@ pub struct SplitSequentialOptions {
|
||||
pub ratios: Option<Vec<f64>>,
|
||||
pub counts: Option<Vec<i64>>,
|
||||
pub fixed: Option<i64>,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct SplitCalculatedOptions {
|
||||
pub calculation: String,
|
||||
pub split_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
@@ -40,7 +49,6 @@ pub struct ShuffleOptions {
|
||||
|
||||
pub struct PermutationBuilderState {
|
||||
pub builder: Option<LancePermutationBuilder>,
|
||||
pub dest_table_name: String,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
@@ -49,11 +57,10 @@ pub struct PermutationBuilder {
|
||||
}
|
||||
|
||||
impl PermutationBuilder {
|
||||
pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
|
||||
pub fn new(builder: LancePermutationBuilder) -> Self {
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(PermutationBuilderState {
|
||||
builder: Some(builder),
|
||||
dest_table_name,
|
||||
})),
|
||||
}
|
||||
}
|
||||
@@ -78,6 +85,16 @@ impl PermutationBuilder {
|
||||
|
||||
#[napi]
|
||||
impl PermutationBuilder {
|
||||
#[napi]
|
||||
pub fn persist(
|
||||
&self,
|
||||
connection: &crate::connection::Connection,
|
||||
table_name: String,
|
||||
) -> napi::Result<Self> {
|
||||
let database = connection.database()?;
|
||||
self.modify(|builder| builder.persist(database, table_name))
|
||||
}
|
||||
|
||||
/// Configure random splits
|
||||
#[napi]
|
||||
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
|
||||
@@ -109,7 +126,12 @@ impl PermutationBuilder {
|
||||
|
||||
let seed = options.seed.map(|s| s as u64);
|
||||
|
||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Random { seed, sizes },
|
||||
options.split_names.clone(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure hash-based splits
|
||||
@@ -122,12 +144,15 @@ impl PermutationBuilder {
|
||||
.collect();
|
||||
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
|
||||
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Hash {
|
||||
columns: options.columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
})
|
||||
self.modify(move |builder| {
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Hash {
|
||||
columns: options.columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
},
|
||||
options.split_names,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -160,14 +185,21 @@ impl PermutationBuilder {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
||||
self.modify(move |builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names)
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure calculated splits
|
||||
#[napi]
|
||||
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
|
||||
self.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
|
||||
pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result<Self> {
|
||||
self.modify(move |builder| {
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Calculated {
|
||||
calculation: options.calculation,
|
||||
},
|
||||
options.split_names,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -191,32 +223,26 @@ impl PermutationBuilder {
|
||||
/// Execute the permutation builder and create the table
|
||||
#[napi]
|
||||
pub async fn execute(&self) -> napi::Result<Table> {
|
||||
let (builder, dest_table_name) = {
|
||||
let builder = {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let builder = state
|
||||
state
|
||||
.builder
|
||||
.take()
|
||||
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
|
||||
|
||||
let dest_table_name = std::mem::take(&mut state.dest_table_name);
|
||||
(builder, dest_table_name)
|
||||
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?
|
||||
};
|
||||
|
||||
let table = builder.build(&dest_table_name).await.default_error()?;
|
||||
let table = builder.build().await.default_error()?;
|
||||
Ok(Table::new(table))
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permutation builder for the given table
|
||||
#[napi]
|
||||
pub fn permutation_builder(
|
||||
table: &crate::table::Table,
|
||||
dest_table_name: String,
|
||||
) -> napi::Result<PermutationBuilder> {
|
||||
use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
|
||||
pub fn permutation_builder(table: &crate::table::Table) -> napi::Result<PermutationBuilder> {
|
||||
use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder;
|
||||
|
||||
let inner_table = table.inner_ref()?.clone();
|
||||
let inner_builder = LancePermutationBuilder::new(inner_table);
|
||||
|
||||
Ok(PermutationBuilder::new(inner_builder, dest_table_name))
|
||||
Ok(PermutationBuilder::new(inner_builder))
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::error::NapiErrorExt;
|
||||
use crate::iterator::RecordBatchIterator;
|
||||
use crate::rerankers::Reranker;
|
||||
use crate::rerankers::RerankerCallbacks;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::util::{parse_distance_type, schema_to_buffer};
|
||||
|
||||
#[napi]
|
||||
pub struct Query {
|
||||
@@ -88,6 +88,12 @@ impl Query {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn output_schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner.output_schema().await.default_error()?;
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
@@ -273,6 +279,12 @@ impl VectorQuery {
|
||||
.rerank(Arc::new(Reranker::new(callbacks)));
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn output_schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner.output_schema().await.default_error()?;
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
@@ -346,6 +358,12 @@ impl TakeQuery {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn output_schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner.output_schema().await.default_error()?;
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use lancedb::ipc::ipc_file_to_batches;
|
||||
use lancedb::table::{
|
||||
AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, NewColumnTransform,
|
||||
@@ -16,6 +15,7 @@ use crate::error::NapiErrorExt;
|
||||
use crate::index::Index;
|
||||
use crate::merge::NativeMergeInsertBuilder;
|
||||
use crate::query::{Query, TakeQuery, VectorQuery};
|
||||
use crate::util::schema_to_buffer;
|
||||
|
||||
#[napi]
|
||||
pub struct Table {
|
||||
@@ -64,14 +64,7 @@ impl Table {
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn schema(&self) -> napi::Result<Buffer> {
|
||||
let schema = self.inner_ref()?.schema().await.default_error()?;
|
||||
let mut writer = FileWriter::try_new(vec![], &schema)
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
|
||||
writer
|
||||
.finish()
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
|
||||
Ok(Buffer::from(writer.into_inner().map_err(|e| {
|
||||
napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
|
||||
})?))
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use arrow_schema::Schema;
|
||||
use lancedb::DistanceType;
|
||||
use napi::bindgen_prelude::Buffer;
|
||||
|
||||
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
|
||||
match distance_type.as_ref().to_lowercase().as_str() {
|
||||
@@ -15,3 +18,15 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert an Arrow Schema to an Arrow IPC file buffer
|
||||
pub fn schema_to_buffer(schema: &Schema) -> napi::Result<Buffer> {
|
||||
let mut writer = FileWriter::try_new(vec![], schema)
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
|
||||
writer
|
||||
.finish()
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
|
||||
Ok(Buffer::from(writer.into_inner().map_err(|e| {
|
||||
napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
|
||||
})?))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.25.3-beta.0"
|
||||
current_version = "0.25.3"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
19
python/AGENTS.md
Normal file
@@ -0,0 +1,19 @@
|
||||
These are the Python bindings of LanceDB.
|
||||
The core Rust library is in the `../rust/lancedb` directory, the rust binding
|
||||
code is in the `src/` directory and the Python bindings are in the `lancedb/` directory.
|
||||
|
||||
Common commands:
|
||||
|
||||
* Build: `make develop`
|
||||
* Format: `make format`
|
||||
* Lint: `make check`
|
||||
* Fix lints: `make fix`
|
||||
* Test: `make test`
|
||||
* Doc test: `make doctest`
|
||||
|
||||
Before committing changes, run lints and then formatting.
|
||||
|
||||
When you change the Rust code, you will need to recompile the Python bindings: `make develop`.
|
||||
|
||||
When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
|
||||
with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.
|
||||
@@ -1,19 +0,0 @@
|
||||
These are the Python bindings of LanceDB.
|
||||
The core Rust library is in the `../rust/lancedb` directory, the rust binding
|
||||
code is in the `src/` directory and the Python bindings are in the `lancedb/` directory.
|
||||
|
||||
Common commands:
|
||||
|
||||
* Build: `make develop`
|
||||
* Format: `make format`
|
||||
* Lint: `make check`
|
||||
* Fix lints: `make fix`
|
||||
* Test: `make test`
|
||||
* Doc test: `make doctest`
|
||||
|
||||
Before committing changes, run lints and then formatting.
|
||||
|
||||
When you change the Rust code, you will need to recompile the Python bindings: `make develop`.
|
||||
|
||||
When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
|
||||
with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.
|
||||
1
python/CLAUDE.md
Symbolic link
@@ -0,0 +1 @@
|
||||
AGENTS.md
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.25.3-beta.0"
|
||||
version = "0.25.3"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
from .table import AsyncTable
|
||||
from .table import AsyncTable, Table
|
||||
from ._lancedb import Session
|
||||
from .namespace import connect_namespace, LanceNamespaceDBConnection
|
||||
|
||||
@@ -233,6 +233,7 @@ __all__ = [
|
||||
"LanceNamespaceDBConnection",
|
||||
"RemoteDBConnection",
|
||||
"Session",
|
||||
"Table",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
||||
@@ -123,6 +123,8 @@ class Table:
|
||||
@property
|
||||
def tags(self) -> Tags: ...
|
||||
def query(self) -> Query: ...
|
||||
def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
|
||||
def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
|
||||
def vector_search(self) -> VectorQuery: ...
|
||||
|
||||
class Tags:
|
||||
@@ -165,6 +167,7 @@ class Query:
|
||||
def postfilter(self): ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
|
||||
def nearest_to_text(self, query: dict) -> FTSQuery: ...
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(
|
||||
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
|
||||
) -> RecordBatchStream: ...
|
||||
@@ -172,6 +175,13 @@ class Query:
|
||||
async def analyze_plan(self) -> str: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class TakeQuery:
|
||||
def select(self, columns: List[str]): ...
|
||||
def with_row_id(self): ...
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(self) -> RecordBatchStream: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class FTSQuery:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
@@ -183,12 +193,14 @@ class FTSQuery:
|
||||
def get_query(self) -> str: ...
|
||||
def add_query_vector(self, query_vec: pa.Array) -> None: ...
|
||||
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(
|
||||
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
|
||||
) -> RecordBatchStream: ...
|
||||
def to_query_request(self) -> PyQueryRequest: ...
|
||||
|
||||
class VectorQuery:
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(self) -> RecordBatchStream: ...
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
@@ -327,3 +339,7 @@ class AsyncPermutationBuilder:
|
||||
def async_permutation_builder(
|
||||
table: Table, dest_table_name: str
|
||||
) -> AsyncPermutationBuilder: ...
|
||||
def fts_query_to_json(query: Any) -> str: ...
|
||||
|
||||
class PermutationReader:
|
||||
def __init__(self, base_table: Table, permutation_table: Table): ...
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Union, Optional, Any
|
||||
from logging import warning
|
||||
from typing import List, Union, Optional, Any, Callable
|
||||
import numpy as np
|
||||
import io
|
||||
import warnings
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
@@ -19,35 +21,52 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
An embedding function that uses the ColPali engine for
|
||||
multimodal multi-vector embeddings.
|
||||
|
||||
This embedding function supports ColQwen2.5 models, producing multivector outputs
|
||||
for both text and image inputs. The output embeddings are lists of vectors, each
|
||||
vector being 128-dimensional by default, represented as List[List[float]].
|
||||
This embedding function supports ColPali models, producing multivector outputs
|
||||
for both text and image inputs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str
|
||||
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
|
||||
Supports models based on these engines:
|
||||
- ColPali: "vidore/colpali-v1.3" and others
|
||||
- ColQwen2.5: "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" and others
|
||||
- ColQwen2: "vidore/colqwen2-v1.0" and others
|
||||
- ColSmol: "vidore/colSmol-256M" and others
|
||||
|
||||
device : str
|
||||
The device for inference (default "cuda:0").
|
||||
The device for inference (default "auto").
|
||||
dtype : str
|
||||
Data type for model weights (default "bfloat16").
|
||||
use_token_pooling : bool
|
||||
Whether to use token pooling to reduce embedding size (default True).
|
||||
DEPRECATED. Whether to use token pooling. Use `pooling_strategy` instead.
|
||||
pooling_strategy : str, optional
|
||||
The token pooling strategy to use, by default "hierarchical".
|
||||
- "hierarchical": Progressively pools tokens to reduce sequence length.
|
||||
- "lambda": A simpler pooling that uses a custom `pooling_func`.
|
||||
pooling_func: typing.Callable, optional
|
||||
A function to use for pooling when `pooling_strategy` is "lambda".
|
||||
pool_factor : int
|
||||
Factor to reduce sequence length if token pooling is enabled (default 2).
|
||||
quantization_config : Optional[BitsAndBytesConfig]
|
||||
Quantization configuration for the model. (default None, bitsandbytes needed)
|
||||
batch_size : int
|
||||
Batch size for processing inputs (default 2).
|
||||
offload_folder: str, optional
|
||||
Folder to offload model weights if using CPU offloading (default None). This is
|
||||
useful for large models that do not fit in memory.
|
||||
"""
|
||||
|
||||
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
|
||||
device: str = "auto"
|
||||
dtype: str = "bfloat16"
|
||||
use_token_pooling: bool = True
|
||||
pooling_strategy: Optional[str] = "hierarchical"
|
||||
pooling_func: Optional[Any] = None
|
||||
pool_factor: int = 2
|
||||
quantization_config: Optional[Any] = None
|
||||
batch_size: int = 2
|
||||
offload_folder: Optional[str] = None
|
||||
|
||||
_model = None
|
||||
_processor = None
|
||||
@@ -56,15 +75,43 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
|
||||
if not self.use_token_pooling:
|
||||
warnings.warn(
|
||||
"use_token_pooling is deprecated, use pooling_strategy=None instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.pooling_strategy = None
|
||||
|
||||
if self.pooling_strategy == "lambda" and self.pooling_func is None:
|
||||
raise ValueError(
|
||||
"pooling_func must be provided when pooling_strategy is 'lambda'"
|
||||
)
|
||||
|
||||
device = self.device
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
dtype = self.dtype
|
||||
if device == "mps" and dtype == "bfloat16":
|
||||
dtype = "float32" # Avoid NaNs on MPS
|
||||
|
||||
(
|
||||
self._model,
|
||||
self._processor,
|
||||
self._token_pooler,
|
||||
) = self._load_model(
|
||||
self.model_name,
|
||||
self.dtype,
|
||||
self.device,
|
||||
self.use_token_pooling,
|
||||
dtype,
|
||||
device,
|
||||
self.pooling_strategy,
|
||||
self.pooling_func,
|
||||
self.quantization_config,
|
||||
)
|
||||
|
||||
@@ -74,16 +121,26 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
model_name: str,
|
||||
dtype: str,
|
||||
device: str,
|
||||
use_token_pooling: bool,
|
||||
pooling_strategy: Optional[str],
|
||||
pooling_func: Optional[Callable],
|
||||
quantization_config: Optional[Any],
|
||||
):
|
||||
"""
|
||||
Initialize and cache the ColPali model, processor, and token pooler.
|
||||
"""
|
||||
if device.startswith("mps"):
|
||||
# warn some torch ops in late interaction architecture result in nans on mps
|
||||
warning(
|
||||
"MPS device detected. Some operations may result in NaNs. "
|
||||
"If you encounter issues, consider using 'cpu' or 'cuda' devices."
|
||||
)
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
transformers = attempt_import_or_raise("transformers", "transformers")
|
||||
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
|
||||
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
|
||||
from colpali_engine.compression.token_pooling import (
|
||||
HierarchicalTokenPooler,
|
||||
LambdaTokenPooler,
|
||||
)
|
||||
|
||||
if quantization_config is not None:
|
||||
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
|
||||
@@ -98,21 +155,45 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
|
||||
model = colpali_engine.models.ColQwen2_5.from_pretrained(
|
||||
model_class, processor_class = None, None
|
||||
model_name_lower = model_name.lower()
|
||||
if "colqwen2.5" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColQwen2_5
|
||||
processor_class = colpali_engine.models.ColQwen2_5_Processor
|
||||
elif "colsmol" in model_name_lower or "colidefics3" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColIdefics3
|
||||
processor_class = colpali_engine.models.ColIdefics3Processor
|
||||
elif "colqwen" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColQwen2
|
||||
processor_class = colpali_engine.models.ColQwen2Processor
|
||||
elif "colpali" in model_name_lower:
|
||||
model_class = colpali_engine.models.ColPali
|
||||
processor_class = colpali_engine.models.ColPaliProcessor
|
||||
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device,
|
||||
quantization_config=quantization_config
|
||||
if quantization_config is not None
|
||||
else None,
|
||||
attn_implementation="flash_attention_2"
|
||||
if is_flash_attn_2_available()
|
||||
else None,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
|
||||
model_name
|
||||
)
|
||||
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
|
||||
model = model.to(device)
|
||||
model = model.to(torch_dtype) # Force cast after moving to device
|
||||
processor = processor_class.from_pretrained(model_name)
|
||||
|
||||
token_pooler = None
|
||||
if pooling_strategy == "hierarchical":
|
||||
token_pooler = HierarchicalTokenPooler()
|
||||
elif pooling_strategy == "lambda":
|
||||
token_pooler = LambdaTokenPooler(pool_func=pooling_func)
|
||||
|
||||
return model, processor, token_pooler
|
||||
|
||||
def ndims(self):
|
||||
@@ -128,7 +209,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
with torch.no_grad():
|
||||
query_embeddings = self._model(**batch_queries)
|
||||
|
||||
if self.use_token_pooling and self._token_pooler is not None:
|
||||
if self.pooling_strategy and self._token_pooler is not None:
|
||||
query_embeddings = self._token_pooler.pool_embeddings(
|
||||
query_embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
@@ -145,13 +226,20 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
Use token pooling if enabled.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
if self.use_token_pooling and self._token_pooler is not None:
|
||||
embeddings = self._token_pooler.pool_embeddings(
|
||||
embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
if self.pooling_strategy and self._token_pooler is not None:
|
||||
if self.pooling_strategy == "hierarchical":
|
||||
embeddings = self._token_pooler.pool_embeddings(
|
||||
embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
elif self.pooling_strategy == "lambda":
|
||||
embeddings = self._token_pooler.pool_embeddings(
|
||||
embeddings,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
tensors = embeddings.detach().cpu()
|
||||
@@ -179,6 +267,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
)
|
||||
with torch.no_grad():
|
||||
query_embeddings = self._model(**batch_queries)
|
||||
query_embeddings = torch.nan_to_num(query_embeddings)
|
||||
all_embeddings.extend(self._process_embeddings(query_embeddings))
|
||||
return all_embeddings
|
||||
|
||||
@@ -225,6 +314,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
)
|
||||
with torch.no_grad():
|
||||
image_embeddings = self._model(**batch_images)
|
||||
image_embeddings = torch.nan_to_num(image_embeddings)
|
||||
all_embeddings.extend(self._process_embeddings(image_embeddings))
|
||||
return all_embeddings
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import base64
|
||||
import os
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
@@ -19,6 +19,23 @@ from .utils import api_key_not_found_help, IMAGES, TEXT
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
|
||||
# Token limits for different VoyageAI models
|
||||
VOYAGE_TOTAL_TOKEN_LIMITS = {
|
||||
"voyage-context-3": 32_000,
|
||||
"voyage-3.5-lite": 1_000_000,
|
||||
"voyage-3.5": 320_000,
|
||||
"voyage-3-lite": 120_000,
|
||||
"voyage-3": 120_000,
|
||||
"voyage-multimodal-3": 120_000,
|
||||
"voyage-finance-2": 120_000,
|
||||
"voyage-multilingual-2": 120_000,
|
||||
"voyage-law-2": 120_000,
|
||||
"voyage-code-2": 120_000,
|
||||
}
|
||||
|
||||
# Batch size for embedding requests (max number of items per batch)
|
||||
BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def is_valid_url(text):
|
||||
try:
|
||||
@@ -120,6 +137,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
name: str
|
||||
The name of the model to use. List of acceptable models:
|
||||
|
||||
* voyage-context-3
|
||||
* voyage-3.5
|
||||
* voyage-3.5-lite
|
||||
* voyage-3
|
||||
* voyage-3-lite
|
||||
* voyage-multimodal-3
|
||||
@@ -157,25 +177,35 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
name: str
|
||||
client: ClassVar = None
|
||||
text_embedding_models: list = [
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
"voyage-3",
|
||||
"voyage-3-lite",
|
||||
"voyage-finance-2",
|
||||
"voyage-multilingual-2",
|
||||
"voyage-law-2",
|
||||
"voyage-code-2",
|
||||
]
|
||||
multimodal_embedding_models: list = ["voyage-multimodal-3"]
|
||||
contextual_embedding_models: list = ["voyage-context-3"]
|
||||
|
||||
def _is_multimodal_model(self, model_name: str):
|
||||
return (
|
||||
model_name in self.multimodal_embedding_models or "multimodal" in model_name
|
||||
)
|
||||
|
||||
def _is_contextual_model(self, model_name: str):
|
||||
return model_name in self.contextual_embedding_models or "context" in model_name
|
||||
|
||||
def ndims(self):
|
||||
if self.name == "voyage-3-lite":
|
||||
return 512
|
||||
elif self.name == "voyage-code-2":
|
||||
return 1536
|
||||
elif self.name in [
|
||||
"voyage-context-3",
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
"voyage-3",
|
||||
"voyage-multimodal-3",
|
||||
"voyage-finance-2",
|
||||
@@ -207,6 +237,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
result = client.multimodal_embed(
|
||||
inputs=[[query]], model=self.name, input_type="query", **kwargs
|
||||
)
|
||||
elif self._is_contextual_model(self.name):
|
||||
result = client.contextualized_embed(
|
||||
inputs=[[query]], model=self.name, input_type="query", **kwargs
|
||||
)
|
||||
result = result.results[0]
|
||||
else:
|
||||
result = client.embed(
|
||||
texts=[query], model=self.name, input_type="query", **kwargs
|
||||
@@ -231,18 +266,164 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
List[np.array]: the list of embeddings
|
||||
"""
|
||||
client = VoyageAIEmbeddingFunction._get_client()
|
||||
|
||||
# For multimodal models, check if inputs contain images
|
||||
if self._is_multimodal_model(self.name):
|
||||
inputs = sanitize_multimodal_input(inputs)
|
||||
result = client.multimodal_embed(
|
||||
inputs=inputs, model=self.name, input_type="document", **kwargs
|
||||
sanitized = sanitize_multimodal_input(inputs)
|
||||
has_images = any(
|
||||
inp["content"][0].get("type") != "text" for inp in sanitized
|
||||
)
|
||||
if has_images:
|
||||
# Use non-batched API for images
|
||||
result = client.multimodal_embed(
|
||||
inputs=sanitized, model=self.name, input_type="document", **kwargs
|
||||
)
|
||||
return result.embeddings
|
||||
# Extract texts for batching
|
||||
inputs = [inp["content"][0]["text"] for inp in sanitized]
|
||||
else:
|
||||
inputs = sanitize_text_input(inputs)
|
||||
result = client.embed(
|
||||
texts=inputs, model=self.name, input_type="document", **kwargs
|
||||
)
|
||||
|
||||
return result.embeddings
|
||||
# Use batching for all text inputs
|
||||
return self._embed_with_batching(
|
||||
client, inputs, input_type="document", **kwargs
|
||||
)
|
||||
|
||||
    def _build_batches(
        self, client, texts: List[str]
    ) -> Generator[List[str], None, None]:
        """
        Generate batches of texts based on token limits using a generator.

        Parameters
        ----------
        client : voyageai.Client
            The VoyageAI client instance.
        texts : List[str]
            List of texts to batch.

        Yields
        ------
        List[str]: Batches of texts.
        """
        if not texts:
            return

        max_tokens_per_batch = VOYAGE_TOTAL_TOKEN_LIMITS.get(self.name, 120_000)
        current_batch: List[str] = []
        current_batch_tokens = 0

        # Tokenize all texts in one API call
        token_lists = client.tokenize(texts, model=self.name)
        token_counts = [len(token_list) for token_list in token_lists]

        for i, text in enumerate(texts):
            n_tokens = token_counts[i]

            # Check if adding this text would exceed limits
            if current_batch and (
                len(current_batch) >= BATCH_SIZE
                or (current_batch_tokens + n_tokens > max_tokens_per_batch)
            ):
                # Yield the current batch and start a new one
                yield current_batch
                current_batch = []
                current_batch_tokens = 0

            current_batch.append(text)
            current_batch_tokens += n_tokens

        # Yield the last batch (always has at least one text)
        if current_batch:
            yield current_batch
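# Illustrative sketch (not part of the library): the same grouping rule in
# isolation, with precomputed token counts standing in for the client's
# tokenize() call. The texts, counts, and limits below are made-up values.
from typing import Iterator, List


def build_batches_example(
    texts: List[str],
    token_counts: List[int],
    max_items: int = 3,
    max_tokens: int = 10,
) -> Iterator[List[str]]:
    batch: List[str] = []
    batch_tokens = 0
    for text, n_tokens in zip(texts, token_counts):
        # Start a new batch when the item count or the token budget would overflow
        if batch and (len(batch) >= max_items or batch_tokens + n_tokens > max_tokens):
            yield batch
            batch, batch_tokens = [], 0
        batch.append(text)
        batch_tokens += n_tokens
    if batch:
        yield batch


fake_token_counts = [4, 4, 4, 9, 1]
batches = list(build_batches_example(["a", "b", "c", "d", "e"], fake_token_counts))
# batches == [["a", "b"], ["c"], ["d", "e"]]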
|
||||
|
||||
def _get_embed_function(
|
||||
self, client, input_type: str = "document", **kwargs
|
||||
) -> callable:
|
||||
"""
|
||||
Get the appropriate embedding function based on model type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : voyageai.Client
|
||||
The VoyageAI client instance.
|
||||
input_type : str
|
||||
Either "query" or "document"
|
||||
**kwargs
|
||||
Additional arguments to pass to the embedding API
|
||||
|
||||
Returns
|
||||
-------
|
||||
callable: A function that takes a batch of texts and returns embeddings.
|
||||
"""
|
||||
if self._is_multimodal_model(self.name):
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
batch_inputs = sanitize_multimodal_input(batch)
|
||||
result = client.multimodal_embed(
|
||||
inputs=batch_inputs,
|
||||
model=self.name,
|
||||
input_type=input_type,
|
||||
**kwargs,
|
||||
)
|
||||
return result.embeddings
|
||||
|
||||
return embed_batch
|
||||
|
||||
elif self._is_contextual_model(self.name):
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
result = client.contextualized_embed(
|
||||
inputs=[batch], model=self.name, input_type=input_type, **kwargs
|
||||
)
|
||||
return result.results[0].embeddings
|
||||
|
||||
return embed_batch
|
||||
|
||||
else:
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
result = client.embed(
|
||||
texts=batch, model=self.name, input_type=input_type, **kwargs
|
||||
)
|
||||
return result.embeddings
|
||||
|
||||
return embed_batch
|
||||
|
||||
def _embed_with_batching(
|
||||
self, client, texts: List[str], input_type: str = "document", **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Embed texts with automatic batching based on token limits.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : voyageai.Client
|
||||
The VoyageAI client instance.
|
||||
texts : List[str]
|
||||
List of texts to embed.
|
||||
input_type : str
|
||||
Either "query" or "document"
|
||||
**kwargs
|
||||
Additional arguments to pass to the embedding API
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[np.array]: List of embeddings.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Get the appropriate embedding function for this model type
|
||||
embed_fn = self._get_embed_function(client, input_type=input_type, **kwargs)
|
||||
|
||||
# Process each batch
|
||||
all_embeddings = []
|
||||
for batch in self._build_batches(client, texts):
|
||||
batch_embeddings = embed_fn(batch)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
@staticmethod
|
||||
def _get_client():
|
||||
|
||||
@@ -1,18 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from ._lancedb import async_permutation_builder
|
||||
from deprecation import deprecated
|
||||
from lancedb import AsyncConnection, DBConnection
|
||||
import pyarrow as pa
|
||||
import json
|
||||
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from typing import Optional
|
||||
from .util import batch_to_tensor
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
|
||||
|
||||
|
||||
class PermutationBuilder:
|
||||
def __init__(self, table: LanceTable, dest_table_name: str):
|
||||
self._async = async_permutation_builder(table, dest_table_name)
|
||||
"""
|
||||
A utility for creating a "permutation table" which is a table that defines an
|
||||
ordering on a base table.
|
||||
|
||||
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
|
||||
self._async.select(projections)
|
||||
The permutation table does not store the actual data. It only stores row
|
||||
ids and split ids to define the ordering. The [Permutation] class can be used to
|
||||
read the data from the base table in the order defined by the permutation table.
|
||||
|
||||
Permutations can split, shuffle, and filter the data in the base table.
|
||||
|
||||
A filter limits the rows that are included in the permutation.
|
||||
Splits divide the data into subsets (for example, a test/train split, or K
|
||||
different splits for cross-validation).
|
||||
Shuffling randomizes the order of the rows in the permutation.
|
||||
|
||||
Splits can optionally be named. If names are provided, the splits can later be
referenced by name. If names are not provided then they can only be referenced
by their ordinal index. There is no requirement to name every split.
|
||||
|
||||
By default, the permutation will be stored in memory and will be lost when the
|
||||
program exits. To persist the permutation (for very large datasets or to share
|
||||
the permutation across multiple workers) use the [persist](#persist) method to
|
||||
create a permanent table.
|
||||
"""
|
||||
|
||||
def __init__(self, table: LanceTable):
|
||||
"""
|
||||
Creates a new permutation builder for the given table.
|
||||
|
||||
By default, the permutation builder will create a single split that contains all
|
||||
rows in the same order as the base table.
|
||||
"""
|
||||
self._async = async_permutation_builder(table)
|
||||
|
||||
def persist(
|
||||
self, database: Union[DBConnection, AsyncConnection], table_name: str
|
||||
) -> "PermutationBuilder":
|
||||
"""
|
||||
Persist the permutation to the given database.
|
||||
"""
|
||||
self._async.persist(database, table_name)
|
||||
return self
|
||||
|
||||
def split_random(
|
||||
@@ -22,8 +67,38 @@ class PermutationBuilder:
|
||||
counts: Optional[list[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
split_names: Optional[list[str]] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
|
||||
"""
|
||||
Configure random splits for the permutation.
|
||||
|
||||
One of ratios, counts, or fixed must be provided.
|
||||
|
||||
If ratios are provided, they will be used to determine the relative size of each
|
||||
split. For example, if ratios are [0.3, 0.7] then the first split will contain
|
||||
30% of the rows and the second split will contain 70% of the rows.
|
||||
|
||||
If counts are provided, they will be used to determine the absolute number of
|
||||
rows in each split. For example, if counts are [100, 200] then the first split
|
||||
will contain 100 rows and the second split will contain 200 rows.
|
||||
|
||||
If fixed is provided, it will be used to determine the number of splits.
|
||||
For example, if fixed is 3 then the permutation will be split evenly into 3
|
||||
splits.
|
||||
|
||||
Rows will be randomly assigned to splits. The optional seed can be provided to
|
||||
make the assignment deterministic.
|
||||
|
||||
The optional split_names can be provided to name the splits. If not provided,
|
||||
the splits can only be referenced by their index.
|
||||
"""
|
||||
self._async.split_random(
|
||||
ratios=ratios,
|
||||
counts=counts,
|
||||
fixed=fixed,
|
||||
seed=seed,
|
||||
split_names=split_names,
|
||||
)
|
||||
return self
|
||||
|
||||
def split_hash(
|
||||
@@ -32,8 +107,33 @@ class PermutationBuilder:
|
||||
split_weights: list[int],
|
||||
*,
|
||||
discard_weight: Optional[int] = None,
|
||||
split_names: Optional[list[str]] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
|
||||
"""
|
||||
Configure hash-based splits for the permutation.
|
||||
|
||||
First, a hash will be calculated over the specified columns. The split weights
are then used to determine how many rows to assign to each split. For example,
if split weights are [1, 2] then the first split will contain 1/3 of the rows
and the second split will contain 2/3 of the rows.
|
||||
|
||||
The optional discard weight can be provided to determine what percentage of rows
|
||||
should be discarded. For example, if split weights are [1, 2] and discard
|
||||
weight is 1 then 25% of the rows will be discarded.
|
||||
|
||||
Hash-based splits are useful if you want the split to be more or less random but
|
||||
you don't want the split assignments to change if rows are added or removed
|
||||
from the table.
|
||||
|
||||
The optional split_names can be provided to name the splits. If not provided,
|
||||
the splits can only be referenced by their index.
|
||||
"""
|
||||
self._async.split_hash(
|
||||
columns,
|
||||
split_weights,
|
||||
discard_weight=discard_weight,
|
||||
split_names=split_names,
|
||||
)
|
||||
return self
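# Worked example (illustrative only): how split_weights and discard_weight
# translate into row fractions. With weights [1, 2] and discard_weight 1, the
# total weight is 4, so the splits receive 25% and 50% of the rows and the
# remaining 25% is discarded.
split_weights = [1, 2]
discard_weight = 1
total = sum(split_weights) + discard_weight
fractions = [w / total for w in split_weights]
discarded = discard_weight / total
# fractions == [0.25, 0.5]; discarded == 0.25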
|
||||
|
||||
def split_sequential(
|
||||
@@ -42,25 +142,85 @@ class PermutationBuilder:
|
||||
ratios: Optional[list[float]] = None,
|
||||
counts: Optional[list[int]] = None,
|
||||
fixed: Optional[int] = None,
|
||||
split_names: Optional[list[str]] = None,
|
||||
) -> "PermutationBuilder":
|
||||
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
|
||||
"""
|
||||
Configure sequential splits for the permutation.
|
||||
|
||||
One of ratios, counts, or fixed must be provided.
|
||||
|
||||
If ratios are provided, they will be used to determine the relative size of each
|
||||
split. For example, if ratios are [0.3, 0.7] then the first split will contain
|
||||
30% of the rows and the second split will contain 70% of the rows.
|
||||
|
||||
If counts are provided, they will be used to determine the absolute number of
|
||||
rows in each split. For example, if counts are [100, 200] then the first split
|
||||
will contain 100 rows and the second split will contain 200 rows.
|
||||
|
||||
If fixed is provided, it will be used to determine the number of splits.
|
||||
For example, if fixed is 3 then the permutation will be split evenly into 3
|
||||
splits.
|
||||
|
||||
Rows will be assigned to splits sequentially. The first N1 rows are assigned to
|
||||
split 1, the next N2 rows are assigned to split 2, etc.
|
||||
|
||||
The optional split_names can be provided to name the splits. If not provided,
|
||||
the splits can only be referenced by their index.
|
||||
"""
|
||||
self._async.split_sequential(
|
||||
ratios=ratios, counts=counts, fixed=fixed, split_names=split_names
|
||||
)
|
||||
return self
|
||||
|
||||
def split_calculated(self, calculation: str) -> "PermutationBuilder":
|
||||
self._async.split_calculated(calculation)
|
||||
def split_calculated(
|
||||
self, calculation: str, split_names: Optional[list[str]] = None
|
||||
) -> "PermutationBuilder":
|
||||
"""
|
||||
Use pre-calculated splits for the permutation.
|
||||
|
||||
The calculation should be an SQL statement that returns an integer value between
|
||||
0 and the number of splits - 1. For example, if you have 3 splits then the
|
||||
calculation should return 0 for the first split, 1 for the second split, and 2
|
||||
for the third split.
|
||||
|
||||
This can be used to implement any kind of user-defined split strategy.
|
||||
|
||||
The optional split_names can be provided to name the splits. If not provided,
|
||||
the splits can only be referenced by their index.
|
||||
"""
|
||||
self._async.split_calculated(calculation, split_names=split_names)
|
||||
return self
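# Illustrative sketch: a calculated split keyed on an existing integer column.
# It mirrors the pattern used in the tests; "id % 3" assigns each row to one of
# three folds based on its id, and the split_names are example labels.
import lancedb
import pyarrow as pa

db = lancedb.connect("memory:///")
tbl = db.create_table("split_demo", pa.table({"id": range(9), "y": range(9)}))
perm_tbl = (
    permutation_builder(tbl)
    .split_calculated("id % 3", split_names=["fold_0", "fold_1", "fold_2"])
    .execute()
)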
|
||||
|
||||
def shuffle(
|
||||
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
|
||||
) -> "PermutationBuilder":
|
||||
"""
|
||||
Randomly shuffle the rows in the permutation.
|
||||
|
||||
An optional seed can be provided to make the shuffle deterministic.
|
||||
|
||||
If a clump size is provided, then data will be shuffled as small "clumps"
|
||||
of contiguous rows. This allows for a balance between randomization and
|
||||
I/O performance. It can be useful when reading from cloud storage.
|
||||
"""
|
||||
self._async.shuffle(seed=seed, clump_size=clump_size)
|
||||
return self
|
||||
|
||||
def filter(self, filter: str) -> "PermutationBuilder":
|
||||
"""
|
||||
Configure a filter for the permutation.
|
||||
|
||||
The filter should be an SQL statement that returns a boolean value for each row.
|
||||
Only rows where the filter is true will be included in the permutation.
|
||||
"""
|
||||
self._async.filter(filter)
|
||||
return self
|
||||
|
||||
def execute(self) -> LanceTable:
|
||||
"""
|
||||
Execute the configuration and create the permutation table.
|
||||
"""
|
||||
|
||||
async def do_execute():
|
||||
inner_tbl = await self._async.execute()
|
||||
return LanceTable.from_inner(inner_tbl)
|
||||
@@ -68,5 +228,594 @@ class PermutationBuilder:
|
||||
return LOOP.run(do_execute())
|
||||
|
||||
|
||||
def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
|
||||
return PermutationBuilder(table, dest_table_name)
|
||||
def permutation_builder(table: LanceTable) -> PermutationBuilder:
|
||||
return PermutationBuilder(table)
|
||||
|
||||
|
||||
class Permutations:
|
||||
"""
|
||||
A collection of permutations indexed by name or ordinal index.
|
||||
|
||||
Splits are defined when the permutation is created. Splits can always be referenced
|
||||
by their ordinal index. If names were provided when the permutation was created
|
||||
then they can also be referenced by name.
|
||||
|
||||
Each permutation or "split" is a view of a portion of the base table. For more
|
||||
details see [Permutation].
|
||||
|
||||
Attributes
|
||||
----------
|
||||
base_table: LanceTable
|
||||
The base table that the permutations are based on.
|
||||
permutation_table: LanceTable
|
||||
The permutation table that defines the splits.
|
||||
split_names: list[str]
|
||||
The names of the splits.
|
||||
split_dict: dict[str, int]
|
||||
A dictionary mapping split names to their ordinal index.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> # Initial data
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("memory:///")
|
||||
>>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(1000)])
|
||||
>>> # Create a permutation
|
||||
>>> perm_tbl = (
|
||||
... permutation_builder(tbl)
|
||||
... .split_random(ratios=[0.95, 0.05], split_names=["train", "test"])
|
||||
... .shuffle()
|
||||
... .execute()
|
||||
... )
|
||||
>>> # Read the permutations
|
||||
>>> permutations = Permutations(tbl, perm_tbl)
|
||||
>>> permutations["train"]
|
||||
<lancedb.permutation.Permutation ...>
|
||||
>>> permutations[0]
|
||||
<lancedb.permutation.Permutation ...>
|
||||
>>> permutations.split_names
|
||||
['train', 'test']
|
||||
>>> permutations.split_dict
|
||||
{'train': 0, 'test': 1}
|
||||
"""
|
||||
|
||||
def __init__(self, base_table: LanceTable, permutation_table: LanceTable):
|
||||
self.base_table = base_table
|
||||
self.permutation_table = permutation_table
|
||||
|
||||
if permutation_table.schema.metadata is not None:
    split_names = permutation_table.schema.metadata.get(b"split_names", None)
    if split_names is not None:
        self.split_names = json.loads(split_names.decode("utf-8"))
        self.split_dict = {
            name: idx for idx, name in enumerate(self.split_names)
        }
|
||||
else:
|
||||
# No split names are defined in the permutation table
|
||||
self.split_names = []
|
||||
self.split_dict = {}
|
||||
else:
|
||||
# No metadata is defined in the permutation table
|
||||
self.split_names = []
|
||||
self.split_dict = {}
|
||||
|
||||
def get_by_name(self, name: str) -> "Permutation":
|
||||
"""
|
||||
Get a permutation by name.
|
||||
|
||||
If no split named `name` is found then an error will be raised.
|
||||
"""
|
||||
idx = self.split_dict.get(name, None)
|
||||
if idx is None:
|
||||
raise ValueError(f"No split named `{name}` found")
|
||||
return self.get_by_index(idx)
|
||||
|
||||
def get_by_index(self, index: int) -> "Permutation":
|
||||
"""
|
||||
Get a permutation by index.
|
||||
"""
|
||||
return Permutation.from_tables(self.base_table, self.permutation_table, index)
|
||||
|
||||
def __getitem__(self, name: Union[str, int]) -> "Permutation":
|
||||
if isinstance(name, str):
|
||||
return self.get_by_name(name)
|
||||
elif isinstance(name, int):
|
||||
return self.get_by_index(name)
|
||||
else:
|
||||
raise TypeError(f"Invalid split name or index: {name}")
|
||||
|
||||
|
||||
class Transforms:
|
||||
"""
|
||||
Namespace for common transformation functions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]:
|
||||
return batch.to_pydict()
|
||||
|
||||
@staticmethod
|
||||
def arrow2arrow(batch: pa.RecordBatch) -> pa.RecordBatch:
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def arrow2numpy(batch: pa.RecordBatch) -> "np.ndarray":
|
||||
return batch.to_pandas().to_numpy()
|
||||
|
||||
@staticmethod
|
||||
def arrow2pandas(batch: pa.RecordBatch) -> "pd.DataFrame":
|
||||
return batch.to_pandas()
|
||||
|
||||
@staticmethod
|
||||
def arrow2polars() -> "pl.DataFrame":
|
||||
import polars as pl
|
||||
|
||||
def impl(batch: pa.RecordBatch) -> pl.DataFrame:
|
||||
return pl.from_arrow(batch)
|
||||
|
||||
return impl
|
||||
|
||||
|
||||
# HuggingFace uses 10 which is pretty small
|
||||
DEFAULT_BATCH_SIZE = 100
|
||||
|
||||
|
||||
class Permutation:
|
||||
"""
|
||||
A Permutation is a view of a dataset that can be used as input to model training
|
||||
and evaluation.
|
||||
|
||||
A Permutation fulfills the pytorch Dataset contract and is loosely modeled after the
|
||||
huggingface Dataset so it should be easy to use with existing code.
|
||||
|
||||
A permutation is not a "materialized view" or copy of the underlying data. It is
|
||||
calculated on the fly from the base table. As a result, it is truly "lazy" and does
|
||||
not require materializing the entire dataset in memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: PermutationReader,
|
||||
selection: dict[str, str],
|
||||
batch_size: int,
|
||||
transform_fn: Callable[[pa.RecordBatch], Any],
|
||||
):
|
||||
"""
|
||||
Internal constructor. Use [from_tables](#from_tables) instead.
|
||||
"""
|
||||
assert reader is not None, "reader is required"
|
||||
assert selection is not None, "selection is required"
|
||||
self.reader = reader
|
||||
self.selection = selection
|
||||
self.transform_fn = transform_fn
|
||||
self.batch_size = batch_size
|
||||
|
||||
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation with the given selection
|
||||
|
||||
Does not validate the selection; the new selection replaces the old one
entirely. This is not intended for public use.
|
||||
"""
|
||||
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
|
||||
|
||||
def _with_reader(self, reader: PermutationReader) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation with the given reader
|
||||
|
||||
This is an internal method and should not be used directly.
|
||||
"""
|
||||
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
|
||||
|
||||
def with_batch_size(self, batch_size: int) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation with the given batch size
|
||||
"""
|
||||
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
|
||||
|
||||
@classmethod
|
||||
def identity(cls, table: LanceTable) -> "Permutation":
|
||||
"""
|
||||
Creates an identity permutation for the given table.
|
||||
"""
|
||||
return Permutation.from_tables(table, None, None)
|
||||
|
||||
@classmethod
|
||||
def from_tables(
|
||||
cls,
|
||||
base_table: LanceTable,
|
||||
permutation_table: Optional[LanceTable] = None,
|
||||
split: Optional[Union[str, int]] = None,
|
||||
) -> "Permutation":
|
||||
"""
|
||||
Creates a permutation from the given base table and permutation table.
|
||||
|
||||
A permutation table identifies which rows, and in what order, the data should
|
||||
be read from the base table. For more details see the [PermutationBuilder]
|
||||
class.
|
||||
|
||||
If no permutation table is provided, then the identity permutation will be
|
||||
created. An identity permutation is a permutation that reads all rows in the
|
||||
base table in the order they are stored.
|
||||
|
||||
The split parameter identifies which split to use. If no split is provided
|
||||
then the first split will be used.
|
||||
"""
|
||||
assert base_table is not None, "base_table is required"
|
||||
if split is not None:
|
||||
if permutation_table is None:
|
||||
raise ValueError(
|
||||
"Cannot create a permutation on split `{split}`"
|
||||
" because no permutation table is provided"
|
||||
)
|
||||
if isinstance(split, str):
|
||||
if permutation_table.schema.metadata is None:
|
||||
raise ValueError(
|
||||
f"Cannot create a permutation on split `{split}`"
|
||||
" because no split names are defined in the permutation table"
|
||||
)
|
||||
split_names = permutation_table.schema.metadata.get(b"split_names", None)
if split_names is None:
    raise ValueError(
        f"Cannot create a permutation on split `{split}`"
        " because no split names are defined in the permutation table"
    )
split_names = json.loads(split_names.decode("utf-8"))
|
||||
try:
|
||||
split = split_names.index(split)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Cannot create a permutation on split `{split}`"
|
||||
f" because split `{split}` is not defined in the "
|
||||
"permutation table"
|
||||
)
|
||||
elif isinstance(split, int):
    pass
|
||||
else:
|
||||
raise TypeError(f"Invalid split: {split}")
|
||||
else:
|
||||
split = 0
|
||||
|
||||
async def do_from_tables():
|
||||
reader = await PermutationReader.from_tables(
|
||||
base_table, permutation_table, split
|
||||
)
|
||||
schema = await reader.output_schema(None)
|
||||
initial_selection = {name: name for name in schema.names}
|
||||
return cls(
|
||||
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
|
||||
)
|
||||
|
||||
return LOOP.run(do_from_tables())
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
async def do_output_schema():
|
||||
return await self.reader.output_schema(self.selection)
|
||||
|
||||
return LOOP.run(do_output_schema())
|
||||
|
||||
@property
|
||||
def num_columns(self) -> int:
|
||||
"""
|
||||
The number of columns in the permutation
|
||||
"""
|
||||
return len(self.schema)
|
||||
|
||||
@property
|
||||
def num_rows(self) -> int:
|
||||
"""
|
||||
The number of rows in the permutation
|
||||
"""
|
||||
return self.reader.count_rows()
|
||||
|
||||
@property
|
||||
def column_names(self) -> list[str]:
|
||||
"""
|
||||
The names of the columns in the permutation
|
||||
"""
|
||||
return self.schema.names
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, int]:
|
||||
"""
|
||||
The shape of the permutation
|
||||
|
||||
This will return self.num_rows, self.num_columns
|
||||
"""
|
||||
return self.num_rows, self.num_columns
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
The number of rows in the permutation
|
||||
|
||||
This is an alias for [num_rows][lancedb.permutation.Permutation.num_rows]
|
||||
"""
|
||||
return self.num_rows
|
||||
|
||||
def unique(self, _column: str) -> list[Any]:
|
||||
"""
|
||||
Get the unique values in the given column
|
||||
"""
|
||||
raise Exception("unique is not yet implemented")
|
||||
|
||||
def flatten(self) -> "Permutation":
|
||||
"""
|
||||
Flatten the permutation
|
||||
|
||||
Each column with a struct type will be flattened into multiple columns.
|
||||
|
||||
This flattening operation happens at read time as a post-processing step
|
||||
so this call is cheap and no data is copied or modified in the underlying
|
||||
dataset.
|
||||
"""
|
||||
raise Exception("flatten is not yet implemented")
|
||||
|
||||
def remove_columns(self, columns: list[str]) -> "Permutation":
|
||||
"""
|
||||
Remove the given columns from the permutation
|
||||
|
||||
Note: this does not actually modify the underlying dataset. It only changes
|
||||
which columns are visible from this permutation. Also, this does not introduce
|
||||
a post-processing step. Instead, we simply do not read those columns in the
|
||||
first place.
|
||||
|
||||
If any of the provided columns does not exist in the current permutation then it
|
||||
will be ignored (no error is raised for missing columns)
|
||||
|
||||
Returns a new permutation with the given columns removed. This does not modify
|
||||
self.
|
||||
"""
|
||||
assert columns is not None, "columns is required"
|
||||
|
||||
new_selection = {
|
||||
name: value for name, value in self.selection.items() if name not in columns
|
||||
}
|
||||
|
||||
if len(new_selection) == 0:
|
||||
raise ValueError("Cannot remove all columns")
|
||||
|
||||
return self._with_selection(new_selection)
|
||||
|
||||
def rename_column(self, old_name: str, new_name: str) -> "Permutation":
|
||||
"""
|
||||
Rename a column in the permutation
|
||||
|
||||
If there is no column named old_name then an error will be raised
|
||||
If there is already a column named new_name then an error will be raised
|
||||
|
||||
Note: this does not actually modify the underlying dataset. It only changes
|
||||
the name of the column that is visible from this permutation. This is a
|
||||
post-processing step but done at the batch level and so it is very cheap.
|
||||
No data will be copied.
|
||||
"""
|
||||
assert old_name is not None, "old_name is required"
|
||||
assert new_name is not None, "new_name is required"
|
||||
if old_name not in self.selection:
|
||||
raise ValueError(
|
||||
f"Cannot rename column `{old_name}` because it does not exist"
|
||||
)
|
||||
if new_name in self.selection:
|
||||
raise ValueError(
|
||||
f"Cannot rename column `{old_name}` to `{new_name}` because a column "
|
||||
"with that name already exists"
|
||||
)
|
||||
new_selection = self.selection.copy()
|
||||
new_selection[new_name] = new_selection[old_name]
|
||||
del new_selection[old_name]
|
||||
return self._with_selection(new_selection)
|
||||
|
||||
def rename_columns(self, column_map: dict[str, str]) -> "Permutation":
|
||||
"""
|
||||
Rename the given columns in the permutation
|
||||
|
||||
If any of the columns do not exist then an error will be raised
|
||||
If any of the new names already exist then an error will be raised
|
||||
|
||||
Note: this does not actually modify the underlying dataset. It only changes
|
||||
the name of the column that is visible from this permutation. This is a
|
||||
post-processing step but done at the batch level and so it is very cheap.
|
||||
No data will be copied.
|
||||
"""
|
||||
assert column_map is not None, "column_map is required"
|
||||
|
||||
new_permutation = self
|
||||
for old_name, new_name in column_map.items():
|
||||
new_permutation = new_permutation.rename_column(old_name, new_name)
|
||||
return new_permutation
|
||||
|
||||
def select_columns(self, columns: list[str]) -> "Permutation":
|
||||
"""
|
||||
Select the given columns from the permutation
|
||||
|
||||
This method refines the current selection, potentially removing columns. It
|
||||
will not add back columns that were previously removed.
|
||||
|
||||
If any of the columns do not exist then an error will be raised
|
||||
|
||||
This does not introduce a post-processing step. It simply reduces the amount
|
||||
of data we read.
|
||||
"""
|
||||
assert columns is not None, "columns is required"
|
||||
if len(columns) == 0:
|
||||
raise ValueError("Must select at least one column")
|
||||
|
||||
new_selection = {}
|
||||
for name in columns:
|
||||
value = self.selection.get(name, None)
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
f"Cannot select column `{name}` because it does not exist"
|
||||
)
|
||||
new_selection[name] = value
|
||||
return self._with_selection(new_selection)
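# Illustrative sketch: selection methods return new Permutation views and never
# mutate the underlying data; reads stay lazy. The table and column names here
# ("x", "y", "label") are example values, not part of the library.
import lancedb
import pyarrow as pa

db = lancedb.connect("memory:///")
tbl = db.create_table("select_demo", pa.table({"x": range(10), "y": range(10)}))
perm = Permutation.identity(tbl)
narrowed = perm.select_columns(["x", "y"]).rename_column("y", "label")
# `perm` still exposes the original columns; `narrowed` exposes "x" and "label".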
|
||||
|
||||
def __iter__(self) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
Iterate over the permutation
|
||||
"""
|
||||
return self.iter(self.batch_size, skip_last_batch=True)
|
||||
|
||||
def iter(
|
||||
self, batch_size: int, skip_last_batch: bool = False
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
Iterate over the permutation in batches
|
||||
|
||||
If skip_last_batch is True, the last batch will be skipped if it is not a
|
||||
multiple of batch_size.
|
||||
"""
|
||||
|
||||
async def get_iter():
|
||||
return await self.reader.read(self.selection, batch_size=batch_size)
|
||||
|
||||
async_iter = LOOP.run(get_iter())
|
||||
|
||||
async def get_next():
|
||||
return await async_iter.__anext__()
|
||||
|
||||
try:
|
||||
while True:
|
||||
batch = LOOP.run(get_next())
|
||||
if batch.num_rows == batch_size or not skip_last_batch:
|
||||
yield self.transform_fn(batch)
|
||||
except StopAsyncIteration:
|
||||
return
|
||||
|
||||
def with_format(
|
||||
self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"]
|
||||
) -> "Permutation":
|
||||
"""
|
||||
Set the format for batches
|
||||
|
||||
If this method is not called, the "python" format will be used.
|
||||
|
||||
The format can be one of:
|
||||
- "numpy" - the batch will be a dict of numpy arrays (one per column)
|
||||
- "python" - the batch will be a dict of lists (one per column)
|
||||
- "pandas" - the batch will be a pandas DataFrame
|
||||
- "arrow" - the batch will be a pyarrow RecordBatch
|
||||
- "torch" - the batch will be a two dimensional torch tensor
|
||||
- "polars" - the batch will be a polars DataFrame
|
||||
|
||||
Conversion may or may not involve a data copy. Lance uses Arrow internally,
so conversion to the arrow and polars formats is zero-copy.
|
||||
|
||||
Conversion to torch will be zero-copy but will only support a subset of data
|
||||
types (numeric types).
|
||||
|
||||
Conversion to numpy and/or pandas will typically be zero-copy for numeric
|
||||
types. Conversion of strings, lists, and structs will require creating python
|
||||
objects and this is not zero-copy.
|
||||
|
||||
For custom formatting, use [with_transform](#with_transform) which overrides
|
||||
this method.
|
||||
"""
|
||||
assert format is not None, "format is required"
|
||||
if format == "python":
|
||||
return self.with_transform(Transforms.arrow2python)
|
||||
elif format == "numpy":
|
||||
return self.with_transform(Transforms.arrow2numpy)
|
||||
elif format == "pandas":
|
||||
return self.with_transform(Transforms.arrow2pandas)
|
||||
elif format == "arrow":
|
||||
return self.with_transform(Transforms.arrow2arrow)
|
||||
elif format == "torch":
|
||||
return self.with_transform(batch_to_tensor)
|
||||
elif format == "polars":
|
||||
return self.with_transform(Transforms.arrow2polars())
|
||||
else:
|
||||
raise ValueError(f"Invalid format: {format}")
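# Illustrative sketch: pick a batch format once and iterate. Only the formats
# listed above are valid; the table name and data here are made-up examples.
import lancedb
import pyarrow as pa

db = lancedb.connect("memory:///")
tbl = db.create_table("format_demo", pa.table({"x": range(8), "y": range(8)}))
perm = Permutation.identity(tbl).with_batch_size(4).with_format("arrow")
for batch in perm:
    assert isinstance(batch, pa.RecordBatch)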
|
||||
|
||||
def with_transform(self, transform: Callable[[pa.RecordBatch], Any]) -> "Permutation":
|
||||
"""
|
||||
Set a custom transform for the permutation
|
||||
|
||||
The transform is a callable that will be invoked with each record batch. The
|
||||
return value will be used as the batch for iteration.
|
||||
|
||||
Note: transforms are not invoked in parallel. This method is not a good place
|
||||
for expensive operations such as image decoding.
|
||||
"""
|
||||
assert transform is not None, "transform is required"
|
||||
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""
|
||||
Return a single row from the permutation
|
||||
|
||||
The output will always be a python dictionary regardless of the format.
|
||||
|
||||
This method is mostly useful for debugging and exploration. For actual
|
||||
processing use [iter](#iter) or a torch data loader to perform batched
|
||||
processing.
|
||||
"""
|
||||
raise Exception("__getitem__ is not yet implemented")
|
||||
|
||||
@deprecated(details="Use with_skip instead")
|
||||
def skip(self, skip: int) -> "Permutation":
|
||||
"""
|
||||
Skip the first `skip` rows of the permutation
|
||||
|
||||
Note: this method returns a new permutation and does not modify `self`
|
||||
It is provided for compatibility with the huggingface Dataset API.
|
||||
|
||||
Use [with_skip](#with_skip) instead to avoid confusion.
|
||||
"""
|
||||
return self.with_skip(skip)
|
||||
|
||||
def with_skip(self, skip: int) -> "Permutation":
|
||||
"""
|
||||
Skip the first `skip` rows of the permutation
|
||||
"""
|
||||
|
||||
async def do_with_skip():
|
||||
reader = await self.reader.with_offset(skip)
|
||||
return self._with_reader(reader)
|
||||
|
||||
return LOOP.run(do_with_skip())
|
||||
|
||||
@deprecated(details="Use with_take instead")
|
||||
def take(self, limit: int) -> "Permutation":
|
||||
"""
|
||||
Limit the permutation to `limit` rows (following any `skip`)
|
||||
|
||||
Note: this method returns a new permutation and does not modify `self`
|
||||
It is provided for compatibility with the huggingface Dataset API.
|
||||
|
||||
Use [with_take](#with_take) instead to avoid confusion.
|
||||
"""
|
||||
return self.with_take(limit)
|
||||
|
||||
def with_take(self, limit: int) -> "Permutation":
|
||||
"""
|
||||
Limit the permutation to `limit` rows (following any `skip`)
|
||||
"""
|
||||
|
||||
async def do_with_take():
|
||||
reader = await self.reader.with_limit(limit)
|
||||
return self._with_reader(reader)
|
||||
|
||||
return LOOP.run(do_with_take())
|
||||
|
||||
@deprecated(details="Use with_repeat instead")
|
||||
def repeat(self, times: int) -> "Permutation":
|
||||
"""
|
||||
Repeat the permutation `times` times
|
||||
|
||||
Note: this method returns a new permutation and does not modify `self`
|
||||
It is provided for compatibility with the huggingface Dataset API.
|
||||
|
||||
Use [with_repeat](#with_repeat) instead to avoid confusion.
|
||||
"""
|
||||
return self.with_repeat(times)
|
||||
|
||||
def with_repeat(self, times: int) -> "Permutation":
|
||||
"""
|
||||
Repeat the permutation `times` times
|
||||
"""
|
||||
raise Exception("with_repeat is not yet implemented")
|
||||
|
||||
@@ -37,7 +37,7 @@ from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import flatten_columns
|
||||
|
||||
from lancedb._lancedb import fts_query_to_json
|
||||
from typing_extensions import Annotated
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -124,6 +124,24 @@ class FullTextQuery(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""
|
||||
Convert the query to a JSON string.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
A JSON string representation of the query.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from lancedb.query import MatchQuery
|
||||
>>> query = MatchQuery("puppy", "text", fuzziness=2)
|
||||
>>> query.to_json()
|
||||
'{"match":{"column":"text","terms":"puppy","boost":1.0,"fuzziness":2,"max_expansions":50,"operator":"Or","prefix_length":0}}'
|
||||
"""
|
||||
return fts_query_to_json(self)
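# Illustrative check: the JSON string round-trips through json.loads into a
# plain dict, which is convenient for logging or asserting on query structure.
import json

from lancedb.query import MatchQuery

parsed = json.loads(MatchQuery("puppy", "text", fuzziness=2).to_json())
assert parsed["match"]["column"] == "text"
assert parsed["match"]["fuzziness"] == 2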
|
||||
|
||||
def __and__(self, other: "FullTextQuery") -> "FullTextQuery":
|
||||
"""
|
||||
Combine two queries with a logical AND operation.
|
||||
@@ -288,6 +306,8 @@ class BooleanQuery(FullTextQuery):
|
||||
----------
|
||||
queries : list[tuple(Occur, FullTextQuery)]
|
||||
The list of queries with their occurrence requirements.
|
||||
Each tuple contains an Occur value (MUST, SHOULD, or MUST_NOT)
|
||||
and a FullTextQuery to apply.
|
||||
"""
|
||||
|
||||
queries: list[tuple[Occur, FullTextQuery]]
|
||||
@@ -1237,6 +1257,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return self._table._output_schema(self.to_query_object())
|
||||
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
@@ -1452,6 +1480,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
offset=self._offset,
|
||||
)
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return self._table._output_schema(self.to_query_object())
|
||||
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
path, fs, exist = self._table._get_fts_index_path()
|
||||
if exist:
|
||||
@@ -1595,6 +1631,10 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
offset=self._offset,
|
||||
)
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
query = self.to_query_object()
|
||||
return self._table._output_schema(query)
|
||||
|
||||
def to_batches(
|
||||
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
@@ -2238,6 +2278,14 @@ class AsyncQueryBase(object):
|
||||
)
|
||||
)
|
||||
|
||||
async def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return await self._inner.output_schema()
|
||||
|
||||
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
Execute the query and collect the results into an Apache Arrow Table.
|
||||
@@ -3193,6 +3241,14 @@ class BaseQueryBuilder(object):
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def output_schema(self) -> pa.Schema:
|
||||
"""
|
||||
Return the output schema for the query
|
||||
|
||||
This does not execute the query.
|
||||
"""
|
||||
return LOOP.run(self._inner.output_schema())
|
||||
|
||||
def to_batches(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -436,6 +436,9 @@ class RemoteTable(Table):
|
||||
def _analyze_plan(self, query: Query) -> str:
|
||||
return LOOP.run(self._table._analyze_plan(query))
|
||||
|
||||
def _output_schema(self, query: Query) -> pa.Schema:
|
||||
return LOOP.run(self._table._output_schema(query))
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||
that can be used to create a "merge insert" operation.
|
||||
|
||||
@@ -21,6 +21,8 @@ class VoyageAIReranker(Reranker):
|
||||
----------
|
||||
model_name : str, default "rerank-english-v2.0"
|
||||
The name of the cross encoder model to use. Available voyageai models are:
|
||||
- rerank-2.5
|
||||
- rerank-2.5-lite
|
||||
- rerank-2
|
||||
- rerank-2-lite
|
||||
column : str, default "text"
|
||||
|
||||
@@ -1248,6 +1248,9 @@ class Table(ABC):
|
||||
@abstractmethod
|
||||
def _analyze_plan(self, query: Query) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def _output_schema(self, query: Query) -> pa.Schema: ...
|
||||
|
||||
@abstractmethod
|
||||
def _do_merge(
|
||||
self,
|
||||
@@ -2761,6 +2764,9 @@ class LanceTable(Table):
|
||||
def _analyze_plan(self, query: Query) -> str:
|
||||
return LOOP.run(self._table._analyze_plan(query))
|
||||
|
||||
def _output_schema(self, query: Query) -> pa.Schema:
|
||||
return LOOP.run(self._table._output_schema(query))
|
||||
|
||||
def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
@@ -3918,6 +3924,10 @@ class AsyncTable:
|
||||
async_query = self._sync_query_to_async(query)
|
||||
return await async_query.analyze_plan()
|
||||
|
||||
async def _output_schema(self, query: Query) -> pa.Schema:
|
||||
async_query = self._sync_query_to_async(query)
|
||||
return await async_query.output_schema()
|
||||
|
||||
async def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
|
||||
@@ -366,3 +366,56 @@ def add_note(base_exception: BaseException, note: str):
|
||||
)
|
||||
else:
|
||||
raise ValueError("Cannot add note to exception")
|
||||
|
||||
|
||||
def tbl_to_tensor(tbl: pa.Table):
|
||||
"""
|
||||
Convert a PyArrow Table to a PyTorch Tensor.
|
||||
|
||||
Each column is converted to a tensor (using zero-copy via DLPack)
|
||||
and the columns are then stacked into a single tensor.
|
||||
|
||||
Fails if torch is not installed.
|
||||
Fails if any column has more than one chunk.
|
||||
Fails if a column's data type is not supported by PyTorch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tbl : pa.Table or pa.RecordBatch
|
||||
The table or record batch to convert to a tensor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor: The tensor containing the columns of the table.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
|
||||
def to_tensor(col: pa.ChunkedArray):
|
||||
if col.num_chunks > 1:
|
||||
raise Exception("Single batch was too large to fit into a one-chunk table")
|
||||
return torch.from_dlpack(col.chunk(0))
|
||||
|
||||
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
|
||||
|
||||
|
||||
def batch_to_tensor(batch: pa.RecordBatch):
|
||||
"""
|
||||
Convert a PyArrow RecordBatch to a PyTorch Tensor.
|
||||
|
||||
Each column is converted to a tensor (using zero-copy via DLPack)
|
||||
and the columns are then stacked into a single tensor.
|
||||
|
||||
Fails if torch is not installed.
|
||||
Fails if a column's data type is not supported by PyTorch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch : pa.RecordBatch
|
||||
The record batch to convert to a tensor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor: The tensor containing the columns of the record batch.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
return torch.stack([torch.from_dlpack(col) for col in batch.columns])
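# Illustrative sketch: converting a small all-numeric RecordBatch (requires
# torch; the values here are made up). Each column becomes one row of the
# stacked tensor.
import pyarrow as pa

batch = pa.RecordBatch.from_pydict(
    {
        "x": pa.array([1.0, 2.0, 3.0], type=pa.float32()),
        "y": pa.array([4.0, 5.0, 6.0], type=pa.float32()),
    }
)
tensor = batch_to_tensor(batch)  # tensor.shape == (2, 3)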
|
||||
|
||||
@@ -532,6 +532,27 @@ def test_voyageai_embedding_function():
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_embedding_function_contextual_model():
|
||||
voyageai = (
|
||||
get_registry().get("voyageai").create(name="voyage-context-3", max_retries=0)
|
||||
)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
@@ -656,6 +677,106 @@ def test_colpali(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("colpali_engine") is None,
|
||||
reason="colpali_engine not installed",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"vidore/colSmol-256M",
|
||||
"vidore/colqwen2.5-v0.2",
|
||||
"vidore/colpali-v1.3",
|
||||
"vidore/colqwen2-v1.0",
|
||||
],
|
||||
)
|
||||
def test_colpali_models(tmp_path, model_name):
|
||||
import requests
|
||||
from lancedb.pydantic import LanceModel
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get("colpali").create(model_name=model_name)
|
||||
|
||||
class MediaItems(LanceModel):
|
||||
text: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
image_vectors: MultiVector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table(f"media_{model_name.replace('/', '_')}", schema=MediaItems)
|
||||
|
||||
texts = [
|
||||
"a cute cat playing with yarn",
|
||||
]
|
||||
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
]
|
||||
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
|
||||
table.add(
|
||||
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
|
||||
)
|
||||
|
||||
image_results = (
|
||||
table.search("fluffy companion", vector_column_name="image_vectors")
|
||||
.limit(1)
|
||||
.to_pydantic(MediaItems)[0]
|
||||
)
|
||||
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
|
||||
|
||||
first_row = table.to_arrow().to_pylist()[0]
|
||||
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
|
||||
assert len(first_row["image_vectors"][0]) == func.ndims(), (
|
||||
"Vector dimension mismatch"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("colpali_engine") is None,
|
||||
reason="colpali_engine not installed",
|
||||
)
|
||||
def test_colpali_pooling(tmp_path):
|
||||
registry = get_registry()
|
||||
model_name = "vidore/colSmol-256M"
|
||||
test_sentence = "a test sentence for pooling"
|
||||
|
||||
# 1. Get embeddings with no pooling
|
||||
func_no_pool = registry.get("colpali").create(
|
||||
model_name=model_name, pooling_strategy=None
|
||||
)
|
||||
unpooled_embeddings = func_no_pool.generate_text_embeddings([test_sentence])[0]
|
||||
original_length = len(unpooled_embeddings)
|
||||
assert original_length > 1
|
||||
|
||||
# 2. Test hierarchical pooling
|
||||
func_hierarchical = registry.get("colpali").create(
|
||||
model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
|
||||
)
|
||||
hierarchical_embeddings = func_hierarchical.generate_text_embeddings(
|
||||
[test_sentence]
|
||||
)[0]
|
||||
expected_hierarchical_length = (original_length + 1) // 2
|
||||
assert len(hierarchical_embeddings) == expected_hierarchical_length
|
||||
|
||||
# 3. Test lambda pooling
|
||||
def simple_pool_func(tensor):
|
||||
return tensor[::2]
|
||||
|
||||
func_lambda = registry.get("colpali").create(
|
||||
model_name=model_name,
|
||||
pooling_strategy="lambda",
|
||||
pooling_func=simple_pool_func,
|
||||
)
|
||||
lambda_embeddings = func_lambda.generate_text_embeddings([test_sentence])[0]
|
||||
expected_lambda_length = (original_length + 1) // 2
|
||||
assert len(lambda_embeddings) == expected_lambda_length
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_siglip(tmp_path, test_images, query_image_bytes):
|
||||
from PIL import Image
|
||||
|
||||
@@ -20,7 +20,14 @@ from unittest import mock
|
||||
import lancedb as ldb
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.index import FTS
|
||||
from lancedb.query import BoostQuery, MatchQuery, MultiMatchQuery, PhraseQuery
|
||||
from lancedb.query import (
|
||||
BoostQuery,
|
||||
MatchQuery,
|
||||
MultiMatchQuery,
|
||||
PhraseQuery,
|
||||
BooleanQuery,
|
||||
Occur,
|
||||
)
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pandas as pd
|
||||
@@ -727,3 +734,146 @@ def test_fts_ngram(mem_db: DBConnection):
|
||||
results = table.search("la", query_type="fts").limit(10).to_list()
|
||||
assert len(results) == 2
|
||||
assert set(r["text"] for r in results) == {"lance database", "lance is cool"}
|
||||
|
||||
|
||||
def test_fts_query_to_json():
|
||||
"""Test that FTS query to_json() produces valid JSON strings with exact format."""
|
||||
|
||||
# Test MatchQuery - basic
|
||||
match_query = MatchQuery("hello world", "text")
|
||||
json_str = match_query.to_json()
|
||||
expected = (
|
||||
'{"match":{"column":"text","terms":"hello world","boost":1.0,'
|
||||
'"fuzziness":0,"max_expansions":50,"operator":"Or","prefix_length":0}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
# Test MatchQuery with options
|
||||
match_query = MatchQuery("puppy", "text", fuzziness=2, boost=1.5, prefix_length=3)
|
||||
json_str = match_query.to_json()
|
||||
expected = (
|
||||
'{"match":{"column":"text","terms":"puppy","boost":1.5,"fuzziness":2,'
|
||||
'"max_expansions":50,"operator":"Or","prefix_length":3}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
# Test PhraseQuery
|
||||
phrase_query = PhraseQuery("quick brown fox", "title")
|
||||
json_str = phrase_query.to_json()
|
||||
expected = '{"phrase":{"column":"title","terms":"quick brown fox","slop":0}}'
|
||||
assert json_str == expected
|
||||
|
||||
# Test PhraseQuery with slop
|
||||
phrase_query = PhraseQuery("quick brown", "title", slop=2)
|
||||
json_str = phrase_query.to_json()
|
||||
expected = '{"phrase":{"column":"title","terms":"quick brown","slop":2}}'
|
||||
assert json_str == expected
|
||||
|
||||
# Test BooleanQuery with MUST
|
||||
must_query = BooleanQuery(
|
||||
[
|
||||
(Occur.MUST, MatchQuery("puppy", "text")),
|
||||
(Occur.MUST, MatchQuery("runs", "text")),
|
||||
]
|
||||
)
|
||||
json_str = must_query.to_json()
|
||||
expected = (
|
||||
'{"boolean":{"should":[],"must":[{"match":{"column":"text","terms":"puppy",'
|
||||
'"boost":1.0,"fuzziness":0,"max_expansions":50,"operator":"Or",'
|
||||
'"prefix_length":0}},{"match":{"column":"text","terms":"runs","boost":1.0,'
|
||||
'"fuzziness":0,"max_expansions":50,"operator":"Or","prefix_length":0}}],'
|
||||
'"must_not":[]}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
# Test BooleanQuery with SHOULD
|
||||
should_query = BooleanQuery(
|
||||
[
|
||||
(Occur.SHOULD, MatchQuery("cat", "text")),
|
||||
(Occur.SHOULD, MatchQuery("dog", "text")),
|
||||
]
|
||||
)
|
||||
json_str = should_query.to_json()
|
||||
expected = (
|
||||
'{"boolean":{"should":[{"match":{"column":"text","terms":"cat","boost":1.0,'
|
||||
'"fuzziness":0,"max_expansions":50,"operator":"Or","prefix_length":0}},'
|
||||
'{"match":{"column":"text","terms":"dog","boost":1.0,"fuzziness":0,'
|
||||
'"max_expansions":50,"operator":"Or","prefix_length":0}}],"must":[],'
|
||||
'"must_not":[]}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
# Test BooleanQuery with MUST_NOT
|
||||
must_not_query = BooleanQuery(
|
||||
[
|
||||
(Occur.MUST, MatchQuery("puppy", "text")),
|
||||
(Occur.MUST_NOT, MatchQuery("training", "text")),
|
||||
]
|
||||
)
|
||||
json_str = must_not_query.to_json()
|
||||
expected = (
|
||||
'{"boolean":{"should":[],"must":[{"match":{"column":"text","terms":"puppy",'
|
||||
'"boost":1.0,"fuzziness":0,"max_expansions":50,"operator":"Or",'
|
||||
'"prefix_length":0}}],"must_not":[{"match":{"column":"text",'
|
||||
'"terms":"training","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||
'"operator":"Or","prefix_length":0}}]}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
# Test BoostQuery
|
||||
positive = MatchQuery("puppy", "text")
|
||||
negative = MatchQuery("training", "text")
|
||||
boost_query = BoostQuery(positive, negative, negative_boost=0.3)
|
||||
json_str = boost_query.to_json()
|
||||
expected = (
|
||||
'{"boost":{"positive":{"match":{"column":"text","terms":"puppy",'
|
||||
'"boost":1.0,"fuzziness":0,"max_expansions":50,"operator":"Or",'
|
||||
'"prefix_length":0}},"negative":{"match":{"column":"text",'
|
||||
'"terms":"training","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||
'"operator":"Or","prefix_length":0}},"negative_boost":0.3}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
# Test MultiMatchQuery
|
||||
multi_match = MultiMatchQuery("python", ["tags", "title"])
|
||||
json_str = multi_match.to_json()
|
||||
expected = (
|
||||
'{"multi_match":{"query":"python","columns":["tags","title"],'
|
||||
'"boost":[1.0,1.0]}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
# Test complex nested BooleanQuery
|
||||
inner1 = BooleanQuery(
|
||||
[
|
||||
(Occur.MUST, MatchQuery("python", "tags")),
|
||||
(Occur.MUST, MatchQuery("tutorial", "title")),
|
||||
]
|
||||
)
|
||||
inner2 = BooleanQuery(
|
||||
[
|
||||
(Occur.MUST, MatchQuery("rust", "tags")),
|
||||
(Occur.MUST, MatchQuery("guide", "title")),
|
||||
]
|
||||
)
|
||||
complex_query = BooleanQuery(
|
||||
[
|
||||
(Occur.SHOULD, inner1),
|
||||
(Occur.SHOULD, inner2),
|
||||
]
|
||||
)
|
||||
json_str = complex_query.to_json()
|
||||
expected = (
|
||||
'{"boolean":{"should":[{"boolean":{"should":[],"must":[{"match":'
|
||||
'{"column":"tags","terms":"python","boost":1.0,"fuzziness":0,'
|
||||
'"max_expansions":50,"operator":"Or","prefix_length":0}},{"match":'
|
||||
'{"column":"title","terms":"tutorial","boost":1.0,"fuzziness":0,'
|
||||
'"max_expansions":50,"operator":"Or","prefix_length":0}}],"must_not":[]}}'
|
||||
',{"boolean":{"should":[],"must":[{"match":{"column":"tags",'
|
||||
'"terms":"rust","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||
'"operator":"Or","prefix_length":0}},{"match":{"column":"title",'
|
||||
'"terms":"guide","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||
'"operator":"Or","prefix_length":0}}],"must_not":[]}}],"must":[],'
|
||||
'"must_not":[]}}'
|
||||
)
|
||||
assert json_str == expected
|
||||
|
||||
@@ -59,6 +59,14 @@ class TempNamespace(LanceNamespace):
|
||||
root
|
||||
] # Reference to shared namespaces
|
||||
|
||||
def namespace_id(self) -> str:
|
||||
"""Return a human-readable unique identifier for this namespace instance.
|
||||
|
||||
Returns:
|
||||
A unique identifier string based on the root directory
|
||||
"""
|
||||
return f"TempNamespace {{ root: '{self.config.root}' }}"
|
||||
|
||||
def list_tables(self, request: ListTablesRequest) -> ListTablesResponse:
|
||||
"""List all tables in the namespace."""
|
||||
if not request.id:
|
||||
|
||||
@@ -2,9 +2,26 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pyarrow as pa
|
||||
import math
|
||||
import pytest
|
||||
|
||||
from lancedb.permutation import permutation_builder
|
||||
from lancedb import DBConnection, Table, connect
|
||||
from lancedb.permutation import Permutation, Permutations, permutation_builder
|
||||
|
||||
|
||||
def test_permutation_persistence(tmp_path):
|
||||
db = connect(tmp_path)
|
||||
tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)}))
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute()
|
||||
)
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
re_open = db.open_table("test_permutation")
|
||||
assert re_open.count_rows() == 100
|
||||
|
||||
assert permutation_tbl.to_arrow() == re_open.to_arrow()
|
||||
|
||||
|
||||
def test_split_random_ratios(mem_db):
|
||||
@@ -12,11 +29,7 @@ def test_split_random_ratios(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_random(ratios=[0.3, 0.7])
|
||||
.execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
|
||||
|
||||
# Check that the table was created and has data
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
@@ -38,11 +51,7 @@ def test_split_random_counts(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_random(counts=[20, 30])
|
||||
.execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
|
||||
|
||||
# Check that we have exactly the requested counts
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
@@ -58,9 +67,7 @@ def test_split_random_fixed(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||
)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()
|
||||
|
||||
# Check that we have 4 splits with 25 rows each
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
@@ -78,17 +85,9 @@ def test_split_random_with_seed(mem_db):
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
|
||||
|
||||
# Create two identical permutations with same seed
|
||||
perm1 = (
|
||||
permutation_builder(tbl, "perm1")
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.execute()
|
||||
)
|
||||
perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "perm2")
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.execute()
|
||||
)
|
||||
perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||
|
||||
# Results should be identical
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
@@ -112,7 +111,7 @@ def test_split_hash(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
@@ -133,7 +132,7 @@ def test_split_hash(mem_db):
|
||||
# Hash splits should be deterministic - same category should go to same split
|
||||
# Let's verify by creating another permutation and checking consistency
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "test_permutation2")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.execute()
|
||||
)
|
||||
@@ -150,7 +149,7 @@ def test_split_hash_with_discard(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
|
||||
.execute()
|
||||
)
|
||||
@@ -168,9 +167,7 @@ def test_split_sequential(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
.split_sequential(counts=[30, 40])
|
||||
.execute()
|
||||
permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 70
|
||||
@@ -194,7 +191,7 @@ def test_split_calculated(mem_db):
|
||||
)
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_calculated("id % 3") # Split based on id modulo 3
|
||||
.execute()
|
||||
)
|
||||
@@ -215,24 +212,34 @@ def test_split_error_cases(mem_db):
|
||||
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
|
||||
|
||||
# Test split_random with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error1").split_random().execute()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
):
|
||||
permutation_builder(tbl).split_random().execute()
|
||||
|
||||
# Test split_random with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error2").split_random(
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
):
|
||||
permutation_builder(tbl).split_random(
|
||||
ratios=[0.5, 0.5], counts=[5, 5]
|
||||
).execute()
|
||||
|
||||
# Test split_sequential with no parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error3").split_sequential().execute()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
):
|
||||
permutation_builder(tbl).split_sequential().execute()
|
||||
|
||||
# Test split_sequential with multiple parameters
|
||||
with pytest.raises(Exception):
|
||||
permutation_builder(tbl, "error4").split_sequential(
|
||||
ratios=[0.5, 0.5], fixed=2
|
||||
).execute()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||
):
|
||||
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
|
||||
|
||||
|
||||
def test_shuffle_no_seed(mem_db):
|
||||
@@ -242,7 +249,7 @@ def test_shuffle_no_seed(mem_db):
|
||||
)
|
||||
|
||||
# Create a permutation with shuffling (no seed)
|
||||
permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
|
||||
permutation_tbl = permutation_builder(tbl).shuffle().execute()
|
||||
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
@@ -262,9 +269,9 @@ def test_shuffle_with_seed(mem_db):
|
||||
)
|
||||
|
||||
# Create two identical permutations with same shuffle seed
|
||||
perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
|
||||
perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||
|
||||
perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
|
||||
perm2 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||
|
||||
# Results should be identical due to same seed
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
@@ -282,7 +289,7 @@ def test_shuffle_with_clump_size(mem_db):
|
||||
|
||||
# Create a permutation with shuffling using clumps
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.shuffle(clump_size=10) # 10-row clumps
|
||||
.execute()
|
||||
)
|
||||
@@ -304,19 +311,9 @@ def test_shuffle_different_seeds(mem_db):
|
||||
)
|
||||
|
||||
# Create two permutations with different shuffle seeds
|
||||
perm1 = (
|
||||
permutation_builder(tbl, "perm1")
|
||||
.split_random(fixed=2)
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
|
||||
|
||||
perm2 = (
|
||||
permutation_builder(tbl, "perm2")
|
||||
.split_random(fixed=2)
|
||||
.shuffle(seed=123)
|
||||
.execute()
|
||||
)
|
||||
perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()
|
||||
|
||||
# Results should be different due to different seeds
|
||||
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||
@@ -341,7 +338,7 @@ def test_shuffle_combined_with_splits(mem_db):
|
||||
|
||||
# Test shuffle with random splits
|
||||
perm_random = (
|
||||
permutation_builder(tbl, "perm_random")
|
||||
permutation_builder(tbl)
|
||||
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||
.shuffle(seed=123, clump_size=None)
|
||||
.execute()
|
||||
@@ -349,7 +346,7 @@ def test_shuffle_combined_with_splits(mem_db):
|
||||
|
||||
# Test shuffle with hash splits
|
||||
perm_hash = (
|
||||
permutation_builder(tbl, "perm_hash")
|
||||
permutation_builder(tbl)
|
||||
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||
.shuffle(seed=456, clump_size=5)
|
||||
.execute()
|
||||
@@ -357,7 +354,7 @@ def test_shuffle_combined_with_splits(mem_db):
|
||||
|
||||
# Test shuffle with sequential splits
|
||||
perm_sequential = (
|
||||
permutation_builder(tbl, "perm_sequential")
|
||||
permutation_builder(tbl)
|
||||
.split_sequential(counts=[40, 35])
|
||||
.shuffle(seed=789, clump_size=None)
|
||||
.execute()
|
||||
@@ -384,7 +381,7 @@ def test_no_shuffle_maintains_order(mem_db):
|
||||
|
||||
# Create permutation without shuffle (should maintain some order)
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.split_sequential(counts=[25, 25]) # Sequential maintains order
|
||||
.execute()
|
||||
)
|
||||
@@ -405,9 +402,7 @@ def test_filter_basic(mem_db):
|
||||
)
|
||||
|
||||
# Filter to only include rows where id < 50
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()
|
||||
|
||||
assert permutation_tbl.count_rows() == 50
|
||||
|
||||
@@ -433,7 +428,7 @@ def test_filter_with_splits(mem_db):
|
||||
|
||||
# Filter to only categories A and B, then split
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.filter("category IN ('A', 'B')")
|
||||
.split_random(ratios=[0.5, 0.5])
|
||||
.execute()
|
||||
@@ -465,7 +460,7 @@ def test_filter_with_shuffle(mem_db):
|
||||
|
||||
# Filter and shuffle
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.filter("category IN ('A', 'C')")
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
@@ -488,9 +483,461 @@ def test_filter_empty_result(mem_db):
|
||||
|
||||
# Filter that matches nothing
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl, "test_permutation")
|
||||
permutation_builder(tbl)
|
||||
.filter("value > 100") # No values > 100 in our data
|
||||
.execute()
|
||||
)
|
||||
|
||||
assert permutation_tbl.count_rows() == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem_db() -> DBConnection:
|
||||
return connect("memory:///")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def some_table(mem_db: DBConnection) -> Table:
|
||||
data = pa.table(
|
||||
{
|
||||
"id": range(1000),
|
||||
"value": range(1000),
|
||||
}
|
||||
)
|
||||
return mem_db.create_table("some_table", data)
|
||||
|
||||
|
||||
def test_no_split_names(some_table: Table):
|
||||
perm_tbl = (
|
||||
permutation_builder(some_table).split_sequential(counts=[500, 500]).execute()
|
||||
)
|
||||
permutations = Permutations(some_table, perm_tbl)
|
||||
assert permutations.split_names == []
|
||||
assert permutations.split_dict == {}
|
||||
assert permutations[0].num_rows == 500
|
||||
assert permutations[1].num_rows == 500
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def some_perm_table(some_table: Table) -> Table:
|
||||
return (
|
||||
permutation_builder(some_table)
|
||||
.split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
def test_nonexistent_split(some_table: Table, some_perm_table: Table):
|
||||
# Reference by name and name does not exist
|
||||
with pytest.raises(ValueError, match="split `nonexistent` is not defined"):
|
||||
Permutation.from_tables(some_table, some_perm_table, "nonexistent")
|
||||
|
||||
# Reference by ordinal and there are no rows
|
||||
with pytest.raises(ValueError, match="No rows found"):
|
||||
Permutation.from_tables(some_table, some_perm_table, 5)
|
||||
|
||||
|
||||
def test_permutations(some_table: Table, some_perm_table: Table):
|
||||
permutations = Permutations(some_table, some_perm_table)
|
||||
assert permutations.split_names == ["train", "test"]
|
||||
assert permutations.split_dict == {"train": 0, "test": 1}
|
||||
assert permutations["train"].num_rows == 950
|
||||
assert permutations[0].num_rows == 950
|
||||
assert permutations["test"].num_rows == 50
|
||||
assert permutations[1].num_rows == 50
|
||||
|
||||
with pytest.raises(ValueError, match="No split named `nonexistent` found"):
|
||||
permutations["nonexistent"]
|
||||
with pytest.raises(ValueError, match="No rows found"):
|
||||
permutations[5]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
|
||||
return Permutation.from_tables(some_table, some_perm_table)
|
||||
|
||||
|
||||
def test_num_rows(some_permutation: Permutation):
|
||||
assert some_permutation.num_rows == 950
|
||||
|
||||
|
||||
def test_num_columns(some_permutation: Permutation):
|
||||
assert some_permutation.num_columns == 2
|
||||
|
||||
|
||||
def test_column_names(some_permutation: Permutation):
|
||||
assert some_permutation.column_names == ["id", "value"]
|
||||
|
||||
|
||||
def test_shape(some_permutation: Permutation):
|
||||
assert some_permutation.shape == (950, 2)
|
||||
|
||||
|
||||
def test_schema(some_permutation: Permutation):
|
||||
assert some_permutation.schema == pa.schema(
|
||||
[("id", pa.int64()), ("value", pa.int64())]
|
||||
)
|
||||
|
||||
|
||||
def test_limit_offset(some_permutation: Permutation):
|
||||
assert some_permutation.with_take(100).num_rows == 100
|
||||
assert some_permutation.with_skip(100).num_rows == 850
|
||||
assert some_permutation.with_take(100).with_skip(100).num_rows == 100
|
||||
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.with_take(1000000).num_rows
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.with_skip(1000000).num_rows
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.with_take(500).with_skip(500).num_rows
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.with_skip(500).with_take(500).num_rows
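# Hedged reading of the cases above: take and skip appear to compose into a single
# window over the 950 permuted rows, so skip 100 + take 100 fits (200 <= 950) while
# skip 500 + take 500 does not (1000 > 950), regardless of call order.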
|
||||
|
||||
|
||||
def test_remove_columns(some_permutation: Permutation):
|
||||
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
|
||||
[("id", pa.int64())]
|
||||
)
|
||||
# Should not modify the original permutation
|
||||
assert some_permutation.schema.names == ["id", "value"]
|
||||
# Cannot remove all columns
|
||||
with pytest.raises(ValueError, match="Cannot remove all columns"):
|
||||
some_permutation.remove_columns(["id", "value"])
|
||||
|
||||
|
||||
def test_rename_column(some_permutation: Permutation):
|
||||
assert some_permutation.rename_column("value", "new_value").schema == pa.schema(
|
||||
[("id", pa.int64()), ("new_value", pa.int64())]
|
||||
)
|
||||
# Should not modify the original permutation
|
||||
assert some_permutation.schema.names == ["id", "value"]
|
||||
# Cannot rename to an existing column
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="a column with that name already exists",
|
||||
):
|
||||
some_permutation.rename_column("value", "id")
|
||||
# Cannot rename a non-existent column
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="does not exist",
|
||||
):
|
||||
some_permutation.rename_column("non_existent", "new_value")
|
||||
|
||||
|
||||
def test_rename_columns(some_permutation: Permutation):
|
||||
assert some_permutation.rename_columns({"value": "new_value"}).schema == pa.schema(
|
||||
[("id", pa.int64()), ("new_value", pa.int64())]
|
||||
)
|
||||
# Should not modify the original permutation
|
||||
assert some_permutation.schema.names == ["id", "value"]
|
||||
# Cannot rename to an existing column
|
||||
with pytest.raises(ValueError, match="a column with that name already exists"):
|
||||
some_permutation.rename_columns({"value": "id"})
|
||||
|
||||
|
||||
def test_select_columns(some_permutation: Permutation):
|
||||
assert some_permutation.select_columns(["id"]).schema == pa.schema(
|
||||
[("id", pa.int64())]
|
||||
)
|
||||
# Should not modify the original permutation
|
||||
assert some_permutation.schema.names == ["id", "value"]
|
||||
# Cannot select a non-existent column
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
some_permutation.select_columns(["non_existent"])
|
||||
# Empty selection is not allowed
|
||||
with pytest.raises(ValueError, match="select at least one column"):
|
||||
some_permutation.select_columns([])
|
||||
|
||||
|
||||
def test_iter_basic(some_permutation: Permutation):
|
||||
"""Test basic iteration with custom batch size."""
|
||||
batch_size = 100
|
||||
batches = list(some_permutation.iter(batch_size, skip_last_batch=False))
|
||||
|
||||
# Check that we got the expected number of batches
|
||||
expected_batches = (950 + batch_size - 1) // batch_size # ceiling division
|
||||
assert len(batches) == expected_batches
|
||||
|
||||
# Check that all batches are dicts (default python format)
|
||||
assert all(isinstance(batch, dict) for batch in batches)
|
||||
|
||||
# Check that batches have the correct structure
|
||||
for batch in batches:
|
||||
assert "id" in batch
|
||||
assert "value" in batch
|
||||
assert isinstance(batch["id"], list)
|
||||
assert isinstance(batch["value"], list)
|
||||
|
||||
# Check that all batches except the last have the correct size
|
||||
for batch in batches[:-1]:
|
||||
assert len(batch["id"]) == batch_size
|
||||
assert len(batch["value"]) == batch_size
|
||||
|
||||
# Last batch might be smaller
|
||||
assert len(batches[-1]["id"]) <= batch_size
|
||||
|
||||
|
||||
def test_iter_skip_last_batch(some_permutation: Permutation):
|
||||
"""Test iteration with skip_last_batch=True."""
|
||||
batch_size = 300
|
||||
batches_with_skip = list(some_permutation.iter(batch_size, skip_last_batch=True))
|
||||
batches_without_skip = list(
|
||||
some_permutation.iter(batch_size, skip_last_batch=False)
|
||||
)
|
||||
|
||||
# With skip_last_batch=True, we should have fewer batches if the last one is partial
|
||||
num_full_batches = 950 // batch_size
|
||||
assert len(batches_with_skip) == num_full_batches
|
||||
|
||||
# Without skip_last_batch, we should have one more batch if there's a remainder
|
||||
if 950 % batch_size != 0:
|
||||
assert len(batches_without_skip) == num_full_batches + 1
|
||||
# Last batch should be smaller
|
||||
assert len(batches_without_skip[-1]["id"]) == 950 % batch_size
|
||||
|
||||
# All batches with skip_last_batch should be full size
|
||||
for batch in batches_with_skip:
|
||||
assert len(batch["id"]) == batch_size
|
||||
|
||||
|
||||
def test_iter_different_batch_sizes(some_permutation: Permutation):
|
||||
"""Test iteration with different batch sizes."""
|
||||
|
||||
# Test with small batch size
|
||||
small_batches = list(some_permutation.iter(100, skip_last_batch=False))
|
||||
assert len(small_batches) == 10 # ceiling(950 / 100)
|
||||
|
||||
# Test with large batch size
|
||||
large_batches = list(some_permutation.iter(400, skip_last_batch=False))
|
||||
assert len(large_batches) == 3 # ceiling(950 / 400)
|
||||
|
||||
# Test with batch size equal to total rows
|
||||
single_batch = list(some_permutation.iter(950, skip_last_batch=False))
|
||||
assert len(single_batch) == 1
|
||||
assert len(single_batch[0]["id"]) == 950
|
||||
|
||||
# Test with batch size larger than total rows
|
||||
oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False))
|
||||
assert len(oversized_batch) == 1
|
||||
assert len(oversized_batch[0]["id"]) == 950
|
||||
|
||||
|
||||
def test_dunder_iter(some_permutation: Permutation):
|
||||
"""Test the __iter__ method."""
|
||||
# __iter__ should use DEFAULT_BATCH_SIZE (100) and skip_last_batch=True
|
||||
batches = list(some_permutation)
|
||||
|
||||
# With DEFAULT_BATCH_SIZE=100 and skip_last_batch=True, we should get 9 batches
|
||||
assert len(batches) == 9  # floor(950 / 100); the partial batch is dropped by skip_last_batch
|
||||
|
||||
# All batches should be full size
|
||||
for batch in batches:
|
||||
assert len(batch["id"]) == 100
|
||||
assert len(batch["value"]) == 100
|
||||
|
||||
some_permutation = some_permutation.with_batch_size(400)
|
||||
batches = list(some_permutation)
|
||||
assert len(batches) == 2 # floor(950 / 400) since skip_last_batch=True
|
||||
for batch in batches:
|
||||
assert len(batch["id"]) == 400
|
||||
assert len(batch["value"]) == 400
|
||||
|
||||
|
||||
def test_iter_with_different_formats(some_permutation: Permutation):
|
||||
"""Test iteration with different output formats."""
|
||||
batch_size = 100
|
||||
|
||||
# Test with arrow format
|
||||
arrow_perm = some_permutation.with_format("arrow")
|
||||
arrow_batches = list(arrow_perm.iter(batch_size, skip_last_batch=False))
|
||||
assert all(isinstance(batch, pa.RecordBatch) for batch in arrow_batches)
|
||||
|
||||
# Test with python format (default)
|
||||
python_perm = some_permutation.with_format("python")
|
||||
python_batches = list(python_perm.iter(batch_size, skip_last_batch=False))
|
||||
assert all(isinstance(batch, dict) for batch in python_batches)
|
||||
|
||||
# Test with pandas format
|
||||
pandas_perm = some_permutation.with_format("pandas")
|
||||
pandas_batches = list(pandas_perm.iter(batch_size, skip_last_batch=False))
|
||||
# Import pandas to check the type
|
||||
import pandas as pd
|
||||
|
||||
assert all(isinstance(batch, pd.DataFrame) for batch in pandas_batches)
|
||||
|
||||
|
||||
def test_iter_with_column_selection(some_permutation: Permutation):
|
||||
"""Test iteration after column selection."""
|
||||
# Select only the id column
|
||||
id_only = some_permutation.select_columns(["id"])
|
||||
batches = list(id_only.iter(100, skip_last_batch=False))
|
||||
|
||||
# Check that batches only contain the id column
|
||||
for batch in batches:
|
||||
assert "id" in batch
|
||||
assert "value" not in batch
|
||||
|
||||
|
||||
def test_iter_with_column_rename(some_permutation: Permutation):
|
||||
"""Test iteration after renaming columns."""
|
||||
renamed = some_permutation.rename_column("value", "data")
|
||||
batches = list(renamed.iter(100, skip_last_batch=False))
|
||||
|
||||
# Check that batches have the renamed column
|
||||
for batch in batches:
|
||||
assert "id" in batch
|
||||
assert "data" in batch
|
||||
assert "value" not in batch
|
||||
|
||||
|
||||
def test_iter_with_limit_offset(some_permutation: Permutation):
|
||||
"""Test iteration with limit and offset."""
|
||||
# Test with offset
|
||||
offset_perm = some_permutation.with_skip(100)
|
||||
offset_batches = list(offset_perm.iter(100, skip_last_batch=False))
|
||||
# Should have 850 rows (950 - 100)
|
||||
expected_batches = math.ceil(850 / 100)
|
||||
assert len(offset_batches) == expected_batches
|
||||
|
||||
# Test with limit
|
||||
limit_perm = some_permutation.with_take(500)
|
||||
limit_batches = list(limit_perm.iter(100, skip_last_batch=False))
|
||||
# Should have 5 batches (500 / 100)
|
||||
assert len(limit_batches) == 5
|
||||
|
||||
no_skip = some_permutation.iter(101, skip_last_batch=False)
|
||||
row_100 = next(no_skip)["id"][100]
|
||||
|
||||
# Test with both limit and offset
|
||||
limited_perm = some_permutation.with_skip(100).with_take(300)
|
||||
limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
|
||||
# Should have 3 batches (300 / 100)
|
||||
assert len(limited_batches) == 3
|
||||
assert limited_batches[0]["id"][0] == row_100
|
||||
|
||||
|
||||
def test_iter_empty_permutation(mem_db):
|
||||
"""Test iteration over an empty permutation."""
|
||||
# Create a table and filter it to be empty
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||
)
|
||||
permutation_tbl = permutation_builder(tbl).filter("value > 100").execute()
|
||||
with pytest.raises(ValueError, match="No rows found"):
|
||||
Permutation.from_tables(tbl, permutation_tbl)
|
||||
|
||||
|
||||
def test_iter_single_row(mem_db):
|
||||
"""Test iteration over a permutation with a single row."""
|
||||
tbl = mem_db.create_table("test_table", pa.table({"id": [42], "value": [100]}))
|
||||
permutation_tbl = permutation_builder(tbl).execute()
|
||||
perm = Permutation.from_tables(tbl, permutation_tbl)
|
||||
|
||||
# With skip_last_batch=False, should get one batch
|
||||
batches = list(perm.iter(10, skip_last_batch=False))
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]["id"]) == 1
|
||||
|
||||
# With skip_last_batch=True, should skip the single row (since it's < batch_size)
|
||||
batches_skip = list(perm.iter(10, skip_last_batch=True))
|
||||
assert len(batches_skip) == 0
|
||||
|
||||
|
||||
def test_identity_permutation(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||
)
|
||||
permutation = Permutation.identity(tbl)
|
||||
|
||||
assert permutation.num_rows == 10
|
||||
assert permutation.num_columns == 2
|
||||
|
||||
batches = list(permutation.iter(10, skip_last_batch=False))
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]["id"]) == 10
|
||||
assert len(batches[0]["value"]) == 10
|
||||
|
||||
permutation = permutation.remove_columns(["value"])
|
||||
assert permutation.num_columns == 1
|
||||
assert permutation.schema == pa.schema([("id", pa.int64())])
|
||||
assert permutation.column_names == ["id"]
|
||||
assert permutation.shape == (10, 1)
|
||||
|
||||
|
||||
def test_transform_fn(mem_db):
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||
)
|
||||
permutation = Permutation.identity(tbl)
|
||||
|
||||
np_result = list(permutation.with_format("numpy").iter(10, skip_last_batch=False))[
|
||||
0
|
||||
]
|
||||
assert np_result.shape == (10, 2)
|
||||
assert np_result.dtype == np.int64
|
||||
assert isinstance(np_result, np.ndarray)
|
||||
|
||||
pd_result = list(permutation.with_format("pandas").iter(10, skip_last_batch=False))[
|
||||
0
|
||||
]
|
||||
assert pd_result.shape == (10, 2)
|
||||
assert pd_result.dtypes.tolist() == [np.int64, np.int64]
|
||||
assert isinstance(pd_result, pd.DataFrame)
|
||||
|
||||
pl_result = list(permutation.with_format("polars").iter(10, skip_last_batch=False))[
|
||||
0
|
||||
]
|
||||
assert pl_result.shape == (10, 2)
|
||||
assert pl_result.dtypes == [pl.Int64, pl.Int64]
|
||||
assert isinstance(pl_result, pl.DataFrame)
|
||||
|
||||
py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[
|
||||
0
|
||||
]
|
||||
assert len(py_result) == 2
|
||||
assert len(py_result["id"]) == 10
|
||||
assert len(py_result["value"]) == 10
|
||||
assert isinstance(py_result, dict)
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch_result = list(
|
||||
permutation.with_format("torch").iter(10, skip_last_batch=False)
|
||||
)[0]
|
||||
assert torch_result.shape == (2, 10)
|
||||
assert torch_result.dtype == torch.int64
|
||||
assert isinstance(torch_result, torch.Tensor)
|
||||
except ImportError:
|
||||
# Skip check if torch is not installed
|
||||
pass
|
||||
|
||||
arrow_result = list(
|
||||
permutation.with_format("arrow").iter(10, skip_last_batch=False)
|
||||
)[0]
|
||||
assert arrow_result.shape == (10, 2)
|
||||
assert arrow_result.schema == pa.schema([("id", pa.int64()), ("value", pa.int64())])
|
||||
assert isinstance(arrow_result, pa.RecordBatch)
|
||||
|
||||
|
||||
def test_custom_transform(mem_db):
|
||||
tbl = mem_db.create_table(
|
||||
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||
)
|
||||
permutation = Permutation.identity(tbl)
|
||||
|
||||
def transform(batch: pa.RecordBatch) -> pa.RecordBatch:
|
||||
return batch.select(["id"])
|
||||
|
||||
transformed = permutation.with_transform(transform)
|
||||
batches = list(transformed.iter(10, skip_last_batch=False))
|
||||
assert len(batches) == 1
|
||||
batch = batches[0]
|
||||
|
||||
assert batch == pa.record_batch([range(10)], ["id"])
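# A further hedged sketch: with_transform can also rewrite values rather than just
# project columns. The pyarrow.compute usage below is illustrative and not part of the
# original tests; column names match the table created above:
#
#     import pyarrow.compute as pc
#
#     def double_values(batch: pa.RecordBatch) -> pa.RecordBatch:
#         doubled = pc.multiply(batch.column("value"), 2)
#         return pa.RecordBatch.from_arrays([batch.column("id"), doubled], ["id", "value"])
#
#     doubled_perm = permutation.with_transform(double_values)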
|
||||
|
||||
@@ -1298,6 +1298,79 @@ async def test_query_serialization_async(table_async: AsyncTable):
|
||||
)
|
||||
|
||||
|
||||
def test_query_schema(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table(
|
||||
"test",
|
||||
pa.table(
|
||||
{
|
||||
"a": [1, 2, 3],
|
||||
"text": ["a", "b", "c"],
|
||||
"vec": pa.array(
|
||||
[[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
assert tbl.search(None).output_schema() == pa.schema(
|
||||
{
|
||||
"a": pa.int64(),
|
||||
"text": pa.string(),
|
||||
"vec": pa.list_(pa.float32(), list_size=2),
|
||||
}
|
||||
)
|
||||
assert tbl.search(None).select({"bl": "a * 2"}).output_schema() == pa.schema(
|
||||
{"bl": pa.int64()}
|
||||
)
|
||||
assert tbl.search([1, 2]).select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64(), "_distance": pa.float32()}
|
||||
)
|
||||
assert tbl.search("blah").select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64()}
|
||||
)
|
||||
assert tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
|
||||
{"text": pa.string()}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_schema_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
tbl = await db.create_table(
|
||||
"test",
|
||||
pa.table(
|
||||
{
|
||||
"a": [1, 2, 3],
|
||||
"text": ["a", "b", "c"],
|
||||
"vec": pa.array(
|
||||
[[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
assert await tbl.query().output_schema() == pa.schema(
|
||||
{
|
||||
"a": pa.int64(),
|
||||
"text": pa.string(),
|
||||
"vec": pa.list_(pa.float32(), list_size=2),
|
||||
}
|
||||
)
|
||||
assert await tbl.query().select({"bl": "a * 2"}).output_schema() == pa.schema(
|
||||
{"bl": pa.int64()}
|
||||
)
|
||||
assert await tbl.vector_search([1, 2]).select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64(), "_distance": pa.float32()}
|
||||
)
|
||||
assert await (await tbl.search("blah")).select(["a"]).output_schema() == pa.schema(
|
||||
{"a": pa.int64()}
|
||||
)
|
||||
assert await tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
|
||||
{"text": pa.string()}
|
||||
)
|
||||
|
||||
|
||||
def test_query_timeout(tmp_path):
|
||||
# Use local directory instead of memory:// to add a bit of latency to
|
||||
# operations so a timeout of zero will trigger exceptions.
|
||||
|
||||
@@ -484,7 +484,7 @@ def test_jina_reranker(tmp_path, use_tantivy):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_voyageai_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("voyageai")
|
||||
reranker = VoyageAIReranker(model_name="rerank-2")
|
||||
reranker = VoyageAIReranker(model_name="rerank-2.5")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@@ -3,19 +3,11 @@
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.util import tbl_to_tensor
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def tbl_to_tensor(tbl):
|
||||
def to_tensor(col: pa.ChunkedArray):
|
||||
if col.num_chunks > 1:
|
||||
raise Exception("Single batch was too large to fit into a one-chunk table")
|
||||
return torch.from_dlpack(col.chunk(0))
|
||||
|
||||
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
|
||||
|
||||
|
||||
def test_table_dataloader(mem_db):
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
database::{CreateTableMode, ReadConsistency},
|
||||
database::{CreateTableMode, Database, ReadConsistency},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
@@ -42,6 +42,10 @@ impl Connection {
|
||||
_ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn database(&self) -> PyResult<Arc<dyn Database>> {
|
||||
Ok(self.get_inner()?.database().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
||||
@@ -5,7 +5,7 @@ use arrow::RecordBatchStream;
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::IndexConfig;
|
||||
use permutation::PyAsyncPermutationBuilder;
|
||||
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
||||
use pyo3::{
|
||||
pymodule,
|
||||
types::{PyModule, PyModuleMethods},
|
||||
@@ -52,9 +52,11 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<DropColumnsResult>()?;
|
||||
m.add_class::<UpdateResult>()?;
|
||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||
m.add_class::<PyPermutationReader>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,23 +3,29 @@
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{error::PythonErrorExt, table::Table};
|
||||
use lancedb::dataloader::{
|
||||
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
split::{SplitSizes, SplitStrategy},
|
||||
use crate::{
|
||||
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
|
||||
};
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
use lancedb::{
|
||||
dataloader::permutation::{
|
||||
builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
|
||||
reader::PermutationReader,
|
||||
split::{SplitSizes, SplitStrategy},
|
||||
},
|
||||
query::Select,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
|
||||
PyResult,
|
||||
exceptions::PyRuntimeError,
|
||||
pyclass, pymethods,
|
||||
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
|
||||
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
/// Create a permutation builder for the given table
|
||||
#[pyo3::pyfunction]
|
||||
pub fn async_permutation_builder(
|
||||
table: Bound<'_, PyAny>,
|
||||
dest_table_name: String,
|
||||
) -> PyResult<PyAsyncPermutationBuilder> {
|
||||
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
|
||||
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let inner_table = table.borrow().inner_ref()?.clone();
|
||||
let inner_builder = LancePermutationBuilder::new(inner_table);
|
||||
@@ -27,14 +33,12 @@ pub fn async_permutation_builder(
|
||||
Ok(PyAsyncPermutationBuilder {
|
||||
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
|
||||
builder: Some(inner_builder),
|
||||
dest_table_name,
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
struct PyAsyncPermutationBuilderState {
|
||||
builder: Option<LancePermutationBuilder>,
|
||||
dest_table_name: String,
|
||||
}
|
||||
|
||||
#[pyclass(name = "AsyncPermutationBuilder")]
|
||||
@@ -61,13 +65,32 @@ impl PyAsyncPermutationBuilder {
|
||||
|
||||
#[pymethods]
|
||||
impl PyAsyncPermutationBuilder {
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
|
||||
#[pyo3(signature = (database, table_name))]
|
||||
pub fn persist(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
database: Bound<'_, PyAny>,
|
||||
table_name: String,
|
||||
) -> PyResult<Self> {
|
||||
let conn = if database.hasattr("_conn")? {
|
||||
database
|
||||
.getattr("_conn")?
|
||||
.getattr("_inner")?
|
||||
.downcast_into::<Connection>()?
|
||||
} else {
|
||||
database.getattr("_inner")?.downcast_into::<Connection>()?
|
||||
};
|
||||
let database = conn.borrow().database()?;
|
||||
slf.modify(|builder| builder.persist(database, table_name))
|
||||
}
|
||||
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
|
||||
pub fn split_random(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
seed: Option<u64>,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
@@ -91,31 +114,38 @@ impl PyAsyncPermutationBuilder {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
|
||||
#[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
|
||||
pub fn split_hash(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
columns: Vec<String>,
|
||||
split_weights: Vec<u64>,
|
||||
discard_weight: u64,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Hash {
|
||||
columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
})
|
||||
builder.with_split_strategy(
|
||||
SplitStrategy::Hash {
|
||||
columns,
|
||||
split_weights,
|
||||
discard_weight,
|
||||
},
|
||||
split_names,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
|
||||
pub fn split_sequential(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
ratios: Option<Vec<f64>>,
|
||||
counts: Option<Vec<u64>>,
|
||||
fixed: Option<u64>,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
// Check that exactly one split type is provided
|
||||
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
|
||||
@@ -139,11 +169,19 @@ impl PyAsyncPermutationBuilder {
|
||||
unreachable!("One of the split arguments must be provided");
|
||||
};
|
||||
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
|
||||
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
|
||||
pub fn split_calculated(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
calculation: String,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
slf.modify(|builder| {
|
||||
builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn shuffle(
|
||||
@@ -167,11 +205,127 @@ impl PyAsyncPermutationBuilder {
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
|
||||
|
||||
let dest_table_name = std::mem::take(&mut state.dest_table_name);
|
||||
|
||||
future_into_py(slf.py(), async move {
|
||||
let table = builder.build(&dest_table_name).await.infer_error()?;
|
||||
let table = builder.build().await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(name = "PermutationReader")]
|
||||
pub struct PyPermutationReader {
|
||||
reader: Arc<PermutationReader>,
|
||||
}
|
||||
|
||||
impl PyPermutationReader {
|
||||
fn from_reader(reader: PermutationReader) -> Self {
|
||||
Self {
|
||||
reader: Arc::new(reader),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_selection(selection: Option<Bound<'_, PyAny>>) -> PyResult<Select> {
|
||||
let Some(selection) = selection else {
|
||||
return Ok(Select::All);
|
||||
};
|
||||
let selection = selection.downcast_into::<PyDict>()?;
|
||||
let selection = selection
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
let key = key.extract::<String>()?;
|
||||
let value = value.extract::<String>()?;
|
||||
Ok((key, value))
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
Ok(Select::dynamic(&selection))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyPermutationReader {
|
||||
#[classmethod]
|
||||
pub fn from_tables<'py>(
|
||||
cls: &Bound<'py, PyType>,
|
||||
base_table: Bound<'py, PyAny>,
|
||||
permutation_table: Option<Bound<'py, PyAny>>,
|
||||
split: u64,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let base_table = base_table.getattr("_inner")?.downcast_into::<Table>()?;
|
||||
let permutation_table = permutation_table
|
||||
.map(|p| PyResult::Ok(p.getattr("_inner")?.downcast_into::<Table>()?))
|
||||
.transpose()?;
|
||||
|
||||
let base_table = base_table.borrow().inner_ref()?.base_table().clone();
|
||||
let permutation_table = permutation_table
|
||||
.map(|p| PyResult::Ok(p.borrow().inner_ref()?.base_table().clone()))
|
||||
.transpose()?;
|
||||
|
||||
future_into_py(cls.py(), async move {
|
||||
let reader = if let Some(permutation_table) = permutation_table {
|
||||
PermutationReader::try_from_tables(base_table, permutation_table, split)
|
||||
.await
|
||||
.infer_error()?
|
||||
} else {
|
||||
PermutationReader::identity(base_table).await
|
||||
};
|
||||
Ok(Self::from_reader(reader))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (selection=None))]
|
||||
pub fn output_schema<'py>(
|
||||
slf: PyRef<'py, Self>,
|
||||
selection: Option<Bound<'py, PyAny>>,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let selection = Self::parse_selection(selection)?;
|
||||
let reader = slf.reader.clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let schema = reader.output_schema(selection).await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn count_rows<'py>(slf: PyRef<'py, Self>) -> u64 {
|
||||
slf.reader.count_rows()
|
||||
}
|
||||
|
||||
#[pyo3(signature = (offset))]
|
||||
pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
|
||||
let reader = slf.reader.as_ref().clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let reader = reader.with_offset(offset).await.infer_error()?;
|
||||
Ok(Self::from_reader(reader))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (limit))]
|
||||
pub fn with_limit<'py>(slf: PyRef<'py, Self>, limit: u64) -> PyResult<Bound<'py, PyAny>> {
|
||||
let reader = slf.reader.as_ref().clone();
|
||||
future_into_py(slf.py(), async move {
|
||||
let reader = reader.with_limit(limit).await.infer_error()?;
|
||||
Ok(Self::from_reader(reader))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (selection=None, *, batch_size=None))]
|
||||
pub fn read<'py>(
|
||||
slf: PyRef<'py, Self>,
|
||||
selection: Option<Bound<'py, PyAny>>,
|
||||
batch_size: Option<u32>,
|
||||
) -> PyResult<Bound<'py, PyAny>> {
|
||||
let selection = Self::parse_selection(selection)?;
|
||||
let reader = slf.reader.clone();
|
||||
let batch_size = batch_size.unwrap_or(1024);
|
||||
future_into_py(slf.py(), async move {
|
||||
use lancedb::query::QueryExecutionOptions;
|
||||
let mut execution_options = QueryExecutionOptions::default();
|
||||
execution_options.max_batch_length = batch_size;
|
||||
let stream = reader
|
||||
.read(selection, execution_options)
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(RecordBatchStream::new(stream))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use arrow::pyarrow::IntoPyArrow;
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
use lancedb::index::scalar::{
|
||||
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||
Operator, PhraseQuery,
|
||||
@@ -22,6 +23,7 @@ use lancedb::query::{
|
||||
};
|
||||
use lancedb::table::AnyQuery;
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
|
||||
use pyo3::pyfunction;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyList;
|
||||
use pyo3::types::{PyDict, PyString};
|
||||
@@ -30,6 +32,7 @@ use pyo3::IntoPyObject;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::Python;
|
||||
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
|
||||
use pyo3::{
|
||||
exceptions::{PyNotImplementedError, PyValueError},
|
||||
@@ -445,6 +448,15 @@ impl Query {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -515,6 +527,15 @@ impl TakeQuery {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -601,6 +622,15 @@ impl FTSQuery {
|
||||
self.inner = self.inner.clone().postfilter();
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -771,6 +801,15 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().bypass_vector_index()
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.output_schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (max_batch_length=None, timeout=None))]
|
||||
pub fn execute(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -944,3 +983,15 @@ impl HybridQuery {
|
||||
req
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a Python FTS query to JSON string
|
||||
#[pyfunction]
|
||||
pub fn fts_query_to_json(query_obj: &Bound<'_, PyAny>) -> PyResult<String> {
|
||||
let wrapped: PyLanceDB<FtsQuery> = query_obj.extract()?;
|
||||
lancedb::table::datafusion::udtf::fts::to_json(&wrapped.0).map_err(|e| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"Failed to serialize FTS query to JSON: {}",
|
||||
e
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.22.2"
|
||||
version = "0.22.3-beta.5"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -16,6 +16,7 @@ arrow = { workspace = true }
|
||||
arrow-array = { workspace = true }
|
||||
arrow-data = { workspace = true }
|
||||
arrow-schema = { workspace = true }
|
||||
arrow-select = { workspace = true }
|
||||
arrow-ord = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
arrow-ipc.workspace = true
|
||||
@@ -41,7 +42,9 @@ lance-table = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
lance-testing = { workspace = true }
|
||||
lance-encoding = { workspace = true }
|
||||
lance-arrow = { workspace = true }
|
||||
lance-namespace = { workspace = true }
|
||||
lance-namespace-impls = { workspace = true }
|
||||
moka = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
@@ -83,10 +86,6 @@ candle-nn = { version = "0.9.1", optional = true }
|
||||
tokenizers = { version = "0.19.1", optional = true }
|
||||
semver = { workspace = true }
|
||||
|
||||
# For a workaround, see workspace Cargo.toml
|
||||
crunchy.workspace = true
|
||||
bytemuck_derive.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
tempfile = "3.5.0"
|
||||
|
||||
@@ -1182,13 +1182,13 @@ mod tests {
|
||||
use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
|
||||
use crate::query::QueryBase;
|
||||
use crate::query::{ExecutableQuery, QueryExecutionOptions};
|
||||
use crate::test_connection::test_utils::new_test_connection;
|
||||
use crate::test_utils::connection::new_test_connection;
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use futures::{stream, TryStreamExt};
|
||||
use lance::error::{ArrowResult, DataFusionResult};
|
||||
use lance_core::error::{ArrowResult, DataFusionResult};
|
||||
use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
|
||||
use tempfile::tempdir;
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ use arrow_array::{
|
||||
use arrow_cast::{can_cast_types, cast};
|
||||
use arrow_schema::{ArrowError, DataType, Field, Schema};
|
||||
use half::f16;
|
||||
use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
|
||||
use lance_arrow::{DataTypeExt, FixedSizeListArrayExt};
|
||||
use log::warn;
|
||||
use num_traits::cast::AsPrimitive;
|
||||
|
||||
@@ -189,7 +189,7 @@ mod tests {
|
||||
};
|
||||
use arrow_schema::Field;
|
||||
use half::f16;
|
||||
use lance::arrow::FixedSizeListArrayExt;
|
||||
use lance_arrow::FixedSizeListArrayExt;
|
||||
|
||||
#[test]
|
||||
fn test_coerce_list_to_fixed_size_list() {
|
||||
|
||||
@@ -455,6 +455,7 @@ impl ListingDatabase {
|
||||
// `remove_dir_all` may be used to remove something not be a dataset
|
||||
lance::Error::NotFound { .. } => Error::TableNotFound {
|
||||
name: name.to_owned(),
|
||||
source: Box::new(err),
|
||||
},
|
||||
_ => Error::from(err),
|
||||
})?;
|
||||
|
||||
@@ -8,13 +8,13 @@ use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lance_namespace::{
|
||||
connect as connect_namespace,
|
||||
models::{
|
||||
CreateEmptyTableRequest, CreateNamespaceRequest, DescribeTableRequest,
|
||||
DropNamespaceRequest, DropTableRequest, ListNamespacesRequest, ListTablesRequest,
|
||||
},
|
||||
LanceNamespace,
|
||||
};
|
||||
use lance_namespace_impls::ConnectBuilder;
|
||||
|
||||
use crate::database::listing::ListingDatabase;
|
||||
use crate::error::{Error, Result};
|
||||
@@ -48,11 +48,16 @@ impl LanceNamespaceDatabase {
|
||||
read_consistency_interval: Option<std::time::Duration>,
|
||||
session: Option<Arc<lance::session::Session>>,
|
||||
) -> Result<Self> {
|
||||
let namespace = connect_namespace(ns_impl, ns_properties.clone())
|
||||
.await
|
||||
.map_err(|e| Error::InvalidInput {
|
||||
message: format!("Failed to connect to namespace: {:?}", e),
|
||||
})?;
|
||||
let mut builder = ConnectBuilder::new(ns_impl);
|
||||
for (key, value) in ns_properties.clone() {
|
||||
builder = builder.property(key, value);
|
||||
}
|
||||
if let Some(ref sess) = session {
|
||||
builder = builder.session(sess.clone());
|
||||
}
|
||||
let namespace = builder.connect().await.map_err(|e| Error::InvalidInput {
|
||||
message: format!("Failed to connect to namespace: {:?}", e),
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
namespace,
|
||||
|
||||
@@ -2,6 +2,3 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
pub mod permutation;
|
||||
pub mod shuffle;
|
||||
pub mod split;
|
||||
pub mod util;
|
||||
|
||||
@@ -7,288 +7,12 @@
|
||||
//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
|
||||
//! the underlying data and can be very lightweight.
|
||||
//!
|
||||
//! Building a permutation table should be fairly quick and memory efficient, even for billions or
|
||||
//! trillions of rows.
|
||||
//! Building a permutation table should be fairly quick (it is an O(N) operation where N is
|
||||
//! the number of rows in the base table) and memory efficient, even for billions or trillions
|
||||
//! of rows.
|
||||
|
||||
use datafusion::prelude::{SessionConfig, SessionContext};
|
||||
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
|
||||
use datafusion_expr::col;
|
||||
use futures::TryStreamExt;
|
||||
use lance_datafusion::exec::SessionContextExt;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
|
||||
dataloader::{
|
||||
shuffle::{Shuffler, ShufflerConfig},
|
||||
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
|
||||
util::{rename_column, TemporaryDirectory},
|
||||
},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
Connection, Error, Result, Table,
|
||||
};
|
||||
|
||||
/// Configuration for creating a permutation table
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PermutationConfig {
|
||||
/// Splitting configuration
|
||||
pub split_strategy: SplitStrategy,
|
||||
/// Shuffle strategy
|
||||
pub shuffle_strategy: ShuffleStrategy,
|
||||
/// Optional filter to apply to the base table
|
||||
pub filter: Option<String>,
|
||||
/// Directory to use for temporary files
|
||||
pub temp_dir: TemporaryDirectory,
|
||||
}
|
||||
|
||||
/// Strategy for shuffling the data.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ShuffleStrategy {
|
||||
/// The data is randomly shuffled
|
||||
///
|
||||
/// A seed can be provided to make the shuffle deterministic.
|
||||
///
|
||||
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
|
||||
/// This decreases the overall randomization but can improve I/O performance when reading from
|
||||
/// cloud storage.
|
||||
///
|
||||
/// For example, a clump size of 16 will means we will shuffle blocks of 16 contiguous rows. This
|
||||
/// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence
|
||||
/// the performance of the model. Note: shuffling within clumps can still be done at read time but
|
||||
/// this will only provide a local shuffle and not a global shuffle.
|
||||
Random {
|
||||
seed: Option<u64>,
|
||||
clump_size: Option<u64>,
|
||||
},
|
||||
/// The data is not shuffled
|
||||
///
|
||||
/// This is useful for debugging and testing.
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for ShuffleStrategy {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating a permutation table.
|
||||
///
|
||||
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
|
||||
/// can be used to create a
|
||||
pub struct PermutationBuilder {
|
||||
config: PermutationConfig,
|
||||
base_table: Table,
|
||||
}
|
||||
|
||||
impl PermutationBuilder {
|
||||
pub fn new(base_table: Table) -> Self {
|
||||
Self {
|
||||
config: PermutationConfig::default(),
|
||||
base_table,
|
||||
}
|
||||
}
|
||||
|
||||
/// Configures the strategy for assigning rows to splits.
|
||||
///
|
||||
/// For example, it is common to create a test/train split of the data. Splits can also be used
|
||||
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
|
||||
/// create a single split with 10% of the data.
|
||||
///
|
||||
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
|
||||
/// multiple processes and multiple nodes.
|
||||
///
|
||||
/// The default is a single split that contains all rows.
|
||||
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
|
||||
self.config.split_strategy = split_strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the strategy for shuffling the data.
|
||||
///
|
||||
/// The default is to shuffle the data randomly at row-level granularity (no shard size) and
|
||||
/// with a random seed.
|
||||
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
|
||||
self.config.shuffle_strategy = shuffle_strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures a filter to apply to the base table.
|
||||
///
|
||||
/// Only rows matching the filter will be included in the permutation.
|
||||
pub fn with_filter(mut self, filter: String) -> Self {
|
||||
self.config.filter = Some(filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the directory to use for temporary files.
|
||||
///
|
||||
/// The default is to use the operating system's default temporary directory.
|
||||
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
|
||||
self.config.temp_dir = temp_dir;
|
||||
self
|
||||
}
|
||||
|
||||
async fn sort_by_split_id(
|
||||
&self,
|
||||
data: SendableRecordBatchStream,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let ctx = SessionContext::new_with_config_rt(
|
||||
SessionConfig::default(),
|
||||
RuntimeEnvBuilder::new()
|
||||
.with_memory_limit(100 * 1024 * 1024, 1.0)
|
||||
.with_disk_manager_builder(
|
||||
DiskManagerBuilder::default()
|
||||
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
|
||||
)
|
||||
.build_arc()
|
||||
.unwrap(),
|
||||
);
|
||||
let df = ctx
|
||||
.read_one_shot(data.into_df_stream())
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to setup sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
let df_stream = df
|
||||
.sort_by(vec![col(SPLIT_ID_COLUMN)])
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to plan sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?
|
||||
.execute_stream()
|
||||
.await
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
|
||||
let schema = df_stream.schema();
|
||||
let stream = df_stream.map_err(|e| Error::Other {
|
||||
message: format!("Failed to execute sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
});
|
||||
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
|
||||
}
|
||||
|
||||
/// Builds the permutation table and stores it in the given database.
|
||||
pub async fn build(self, dest_table_name: &str) -> Result<Table> {
|
||||
// First pass, apply filter and load row ids
|
||||
let mut rows = self.base_table.query().with_row_id();
|
||||
|
||||
if let Some(filter) = &self.config.filter {
|
||||
rows = rows.only_if(filter);
|
||||
}
|
||||
|
||||
let splitter = Splitter::new(
|
||||
self.config.temp_dir.clone(),
|
||||
self.config.split_strategy.clone(),
|
||||
);
|
||||
|
||||
let mut needs_sort = !splitter.orders_by_split_id();
|
||||
|
||||
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
|
||||
// split id)
|
||||
rows = splitter.project(rows);
|
||||
|
||||
let num_rows = self
|
||||
.base_table
|
||||
.count_rows(self.config.filter.clone())
|
||||
.await? as u64;
|
||||
|
||||
// Apply splits
|
||||
let rows = rows.execute().await?;
|
||||
let split_data = splitter.apply(rows, num_rows).await?;
|
||||
|
||||
// Shuffle data if requested
|
||||
let shuffled = match self.config.shuffle_strategy {
|
||||
ShuffleStrategy::None => split_data,
|
||||
ShuffleStrategy::Random { seed, clump_size } => {
|
||||
let shuffler = Shuffler::new(ShufflerConfig {
|
||||
seed,
|
||||
clump_size,
|
||||
temp_dir: self.config.temp_dir.clone(),
|
||||
max_rows_per_file: 10 * 1024 * 1024,
|
||||
});
|
||||
shuffler.shuffle(split_data, num_rows).await?
|
||||
}
|
||||
};
|
||||
|
||||
// We want the final permutation to be sorted by the split id. If we shuffled or if
|
||||
// the split was not assigned sequentially then we need to sort the data.
|
||||
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
|
||||
|
||||
let sorted = if needs_sort {
|
||||
self.sort_by_split_id(shuffled).await?
|
||||
} else {
|
||||
shuffled
|
||||
};
|
||||
|
||||
// Rename _rowid to row_id
|
||||
let renamed = rename_column(sorted, "_rowid", "row_id")?;
|
||||
|
||||
// Create permutation table
|
||||
let conn = Connection::new(
|
||||
self.base_table.database().clone(),
|
||||
self.base_table.embedding_registry().clone(),
|
||||
);
|
||||
conn.create_table_streaming(dest_table_name, renamed)
|
||||
.execute()
|
||||
.await
|
||||
}
|
||||
}
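// Illustrative usage sketch (hypothetical helper; the destination name and split percentages
// are arbitrary). Builds a deterministic 80/20 permutation of `table` and stores it alongside
// the base table.
async fn example_build_permutation(table: Table) -> Result<Table> {
    PermutationBuilder::new(table)
        .with_split_strategy(SplitStrategy::Random {
            seed: Some(42),
            sizes: crate::dataloader::split::SplitSizes::Percentages(vec![0.8, 0.2]),
        })
        .with_shuffle_strategy(ShuffleStrategy::Random {
            seed: Some(42),
            clump_size: None,
        })
        .build("my_table_permutation")
        .await
}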
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow::datatypes::Int32Type;
|
||||
use lance_datagen::{BatchCount, RowCount};
|
||||
|
||||
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_permutation_builder() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
|
||||
let db = connect(temp_dir.path().to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let initial_data = lance_datagen::gen_batch()
|
||||
.col("some_value", lance_datagen::array::step::<Int32Type>())
|
||||
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
|
||||
let data_table = db
|
||||
.create_table_streaming("mytbl", initial_data)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let permutation_table = PermutationBuilder::new(data_table)
|
||||
.with_filter("some_value > 57".to_string())
|
||||
.with_split_strategy(SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
|
||||
})
|
||||
.build("permutation")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Potentially brittle seed-dependent values below
|
||||
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
|
||||
assert_eq!(
|
||||
permutation_table
|
||||
.count_rows(Some("split_id = 0".to_string()))
|
||||
.await
|
||||
.unwrap(),
|
||||
47
|
||||
);
|
||||
assert_eq!(
|
||||
permutation_table
|
||||
.count_rows(Some("split_id = 1".to_string()))
|
||||
.await
|
||||
.unwrap(),
|
||||
283
|
||||
);
|
||||
}
|
||||
}
|
||||
pub mod builder;
|
||||
pub mod reader;
|
||||
pub mod shuffle;
|
||||
pub mod split;
|
||||
pub mod util;
|
||||
|
||||
374
rust/lancedb/src/dataloader/permutation/builder.rs
Normal file
@@ -0,0 +1,374 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use datafusion::prelude::{SessionConfig, SessionContext};
|
||||
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
|
||||
use datafusion_expr::col;
|
||||
use futures::TryStreamExt;
|
||||
use lance_core::ROW_ID;
|
||||
use lance_datafusion::exec::SessionContextExt;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
|
||||
connect,
|
||||
database::{CreateTableData, CreateTableRequest, Database},
|
||||
dataloader::permutation::{
|
||||
shuffle::{Shuffler, ShufflerConfig},
|
||||
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
|
||||
util::{rename_column, TemporaryDirectory},
|
||||
},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
Error, Result, Table,
|
||||
};
|
||||
|
||||
pub const SRC_ROW_ID_COL: &str = "row_id";
|
||||
|
||||
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
|
||||
|
||||
/// Where to store the permutation table
|
||||
#[derive(Debug, Clone, Default)]
|
||||
enum PermutationDestination {
|
||||
/// The permutation table is a temporary table in memory
|
||||
#[default]
|
||||
Temporary,
|
||||
/// The permutation table is a permanent table in a database
|
||||
Permanent(Arc<dyn Database>, String),
|
||||
}
|
||||
|
||||
/// Configuration for creating a permutation table
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PermutationConfig {
|
||||
/// Splitting configuration
|
||||
split_strategy: SplitStrategy,
|
||||
/// Optional names for the splits
|
||||
split_names: Option<Vec<String>>,
|
||||
/// Shuffle strategy
|
||||
shuffle_strategy: ShuffleStrategy,
|
||||
/// Optional filter to apply to the base table
|
||||
filter: Option<String>,
|
||||
/// Directory to use for temporary files
|
||||
temp_dir: TemporaryDirectory,
|
||||
/// Destination
|
||||
destination: PermutationDestination,
|
||||
}
|
||||
|
||||
/// Strategy for shuffling the data.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ShuffleStrategy {
|
||||
/// The data is randomly shuffled
|
||||
///
|
||||
/// A seed can be provided to make the shuffle deterministic.
|
||||
///
|
||||
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
|
||||
/// This decreases the overall randomization but can improve I/O performance when reading from
|
||||
/// cloud storage.
|
||||
///
|
||||
/// For example, a clump size of 16 means blocks of 16 contiguous rows are shuffled as a unit.
/// This requires roughly 16x fewer IOPS, but those 16 rows always stay close together, which can
/// influence model quality. Note: shuffling within clumps can still be done at read time, but that
/// only provides a local shuffle, not a global one.
|
||||
Random {
|
||||
seed: Option<u64>,
|
||||
clump_size: Option<u64>,
|
||||
},
|
||||
/// The data is not shuffled
|
||||
///
|
||||
/// This is useful for debugging and testing.
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for ShuffleStrategy {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for creating a permutation table.
|
||||
///
|
||||
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
|
||||
/// can be used to create a permutation reader that reads rows in the order defined by the permutation.
|
||||
///
|
||||
/// The permutation table is neither a copy nor a view of the underlying data. It is a separate,
/// very lightweight table that stores just the row id and split id for each selected row.
|
||||
pub struct PermutationBuilder {
|
||||
config: PermutationConfig,
|
||||
base_table: Table,
|
||||
}
|
||||
|
||||
impl PermutationBuilder {
|
||||
pub fn new(base_table: Table) -> Self {
|
||||
Self {
|
||||
config: PermutationConfig::default(),
|
||||
base_table,
|
||||
}
|
||||
}
|
||||
|
||||
/// Configures the strategy for assigning rows to splits.
|
||||
///
|
||||
/// For example, it is common to create a test/train split of the data. Splits can also be used
|
||||
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
|
||||
/// create a single split with 10% of the data.
|
||||
///
|
||||
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
|
||||
/// multiple processes and multiple nodes.
|
||||
///
|
||||
/// The default is a single split that contains all rows.
|
||||
///
|
||||
/// An optional list of names can be provided for the splits. This is for convenience and the names
|
||||
/// will be stored in the permutation table's config metadata.
|
||||
pub fn with_split_strategy(
|
||||
mut self,
|
||||
split_strategy: SplitStrategy,
|
||||
split_names: Option<Vec<String>>,
|
||||
) -> Self {
|
||||
self.config.split_strategy = split_strategy;
|
||||
self.config.split_names = split_names;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the strategy for shuffling the data.
|
||||
///
|
||||
/// The default is to shuffle the data randomly at row-level granularity (no clump size) and
|
||||
/// with a random seed.
|
||||
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
|
||||
self.config.shuffle_strategy = shuffle_strategy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures a filter to apply to the base table.
|
||||
///
|
||||
/// Only rows matching the filter will be included in the permutation.
|
||||
pub fn with_filter(mut self, filter: String) -> Self {
|
||||
self.config.filter = Some(filter);
|
||||
self
|
||||
}
|
||||
|
||||
/// Configures the directory to use for temporary files.
|
||||
///
|
||||
/// The default is to use the operating system's default temporary directory.
|
||||
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
|
||||
self.config.temp_dir = temp_dir;
|
||||
self
|
||||
}
|
||||
|
||||
/// Stores the permutation as a table in a database
|
||||
///
|
||||
/// By default, the permutation is stored in memory. If this method is called then
|
||||
/// the permutation will be stored as a table in the given database.
|
||||
pub fn persist(mut self, database: Arc<dyn Database>, table_name: String) -> Self {
|
||||
self.config.destination = PermutationDestination::Permanent(database, table_name);
|
||||
self
|
||||
}
|
||||
|
||||
async fn sort_by_split_id(
|
||||
&self,
|
||||
data: SendableRecordBatchStream,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let ctx = SessionContext::new_with_config_rt(
|
||||
SessionConfig::default(),
|
||||
RuntimeEnvBuilder::new()
|
||||
.with_memory_limit(100 * 1024 * 1024, 1.0)
|
||||
.with_disk_manager_builder(
|
||||
DiskManagerBuilder::default()
|
||||
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
|
||||
)
|
||||
.build_arc()
|
||||
.unwrap(),
|
||||
);
|
||||
let df = ctx
|
||||
.read_one_shot(data.into_df_stream())
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to setup sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
let df_stream = df
|
||||
.sort_by(vec![col(SPLIT_ID_COLUMN)])
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to plan sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?
|
||||
.execute_stream()
|
||||
.await
|
||||
.map_err(|e| Error::Other {
|
||||
message: format!("Failed to sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?;
|
||||
|
||||
let schema = df_stream.schema();
|
||||
let stream = df_stream.map_err(|e| Error::Other {
|
||||
message: format!("Failed to execute sort by split id: {}", e),
|
||||
source: Some(e.into()),
|
||||
});
|
||||
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
|
||||
}
|
||||
|
||||
fn add_split_names(
|
||||
data: SendableRecordBatchStream,
|
||||
split_names: &[String],
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let schema = data
|
||||
.schema()
|
||||
.as_ref()
|
||||
.clone()
|
||||
.with_metadata(HashMap::from([(
|
||||
SPLIT_NAMES_CONFIG_KEY.to_string(),
|
||||
serde_json::to_string(split_names).map_err(|e| Error::Other {
|
||||
message: format!("Failed to serialize split names: {}", e),
|
||||
source: Some(e.into()),
|
||||
})?,
|
||||
)]));
|
||||
let schema = Arc::new(schema);
|
||||
let schema_clone = schema.clone();
|
||||
let stream = data.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
|
||||
Ok(Box::pin(SimpleRecordBatchStream {
|
||||
schema: schema_clone,
|
||||
stream,
|
||||
}))
|
||||
}
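    /// Illustrative sketch (hypothetical helper): reading the split names written by
    /// `add_split_names` back out of a permutation table's schema metadata.
    async fn example_read_split_names(permutation: &Table) -> Result<Option<Vec<String>>> {
        let schema = permutation.schema().await?;
        let Some(json) = schema.metadata().get(SPLIT_NAMES_CONFIG_KEY) else {
            return Ok(None);
        };
        let names: Vec<String> = serde_json::from_str(json).map_err(|e| Error::Other {
            message: format!("Failed to parse split names: {}", e),
            source: Some(e.into()),
        })?;
        Ok(Some(names))
    }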
|
||||
|
||||
/// Builds the permutation table and stores it in the given database.
|
||||
pub async fn build(self) -> Result<Table> {
|
||||
// First pass, apply filter and load row ids
|
||||
let mut rows = self.base_table.query().with_row_id();
|
||||
|
||||
if let Some(filter) = &self.config.filter {
|
||||
rows = rows.only_if(filter);
|
||||
}
|
||||
|
||||
let splitter = Splitter::new(
|
||||
self.config.temp_dir.clone(),
|
||||
self.config.split_strategy.clone(),
|
||||
);
|
||||
|
||||
let mut needs_sort = !splitter.orders_by_split_id();
|
||||
|
||||
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
|
||||
// split id)
|
||||
rows = splitter.project(rows);
|
||||
|
||||
let num_rows = self
|
||||
.base_table
|
||||
.count_rows(self.config.filter.clone())
|
||||
.await? as u64;
|
||||
|
||||
// Apply splits
|
||||
let rows = rows.execute().await?;
|
||||
let split_data = splitter.apply(rows, num_rows).await?;
|
||||
|
||||
// Shuffle data if requested
|
||||
let shuffled = match self.config.shuffle_strategy {
|
||||
ShuffleStrategy::None => split_data,
|
||||
ShuffleStrategy::Random { seed, clump_size } => {
|
||||
let shuffler = Shuffler::new(ShufflerConfig {
|
||||
seed,
|
||||
clump_size,
|
||||
temp_dir: self.config.temp_dir.clone(),
|
||||
max_rows_per_file: 10 * 1024 * 1024,
|
||||
});
|
||||
shuffler.shuffle(split_data, num_rows).await?
|
||||
}
|
||||
};
|
||||
|
||||
// We want the final permutation to be sorted by the split id. If we shuffled or if
|
||||
// the split was not assigned sequentially then we need to sort the data.
|
||||
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
|
||||
|
||||
let sorted = if needs_sort {
|
||||
self.sort_by_split_id(shuffled).await?
|
||||
} else {
|
||||
shuffled
|
||||
};
|
||||
|
||||
// Rename _rowid to row_id
|
||||
let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;
|
||||
|
||||
let streaming_data = if let Some(split_names) = &self.config.split_names {
|
||||
Self::add_split_names(renamed, split_names)?
|
||||
} else {
|
||||
renamed
|
||||
};
|
||||
|
||||
let (name, database) = match &self.config.destination {
|
||||
PermutationDestination::Permanent(database, table_name) => {
|
||||
(table_name.as_str(), database.clone())
|
||||
}
|
||||
PermutationDestination::Temporary => {
|
||||
let conn = connect("memory:///").execute().await?;
|
||||
("permutation", conn.database().clone())
|
||||
}
|
||||
};
|
||||
|
||||
let create_table_request = CreateTableRequest::new(
|
||||
name.to_string(),
|
||||
CreateTableData::StreamingData(streaming_data),
|
||||
);
|
||||
|
||||
let table = database.create_table(create_table_request).await?;
|
||||
|
||||
Ok(Table::new(table, database))
|
||||
}
|
||||
}
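// Illustrative usage sketch (hypothetical helper; the split names, percentages, and destination
// table name are arbitrary, and `base` is assumed to have been opened through a Connection so it
// carries a database handle). Builds a named train/test permutation and persists it in the same
// database as the base table.
async fn example_named_permutation(base: Table) -> Result<Table> {
    let database = base.database().clone();
    PermutationBuilder::new(base)
        .with_split_strategy(
            SplitStrategy::Random {
                seed: Some(7),
                sizes: crate::dataloader::permutation::split::SplitSizes::Percentages(vec![0.8, 0.2]),
            },
            Some(vec!["train".to_string(), "test".to_string()]),
        )
        .persist(database, "base_permutation".to_string())
        .build()
        .await
}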
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow::datatypes::Int32Type;
|
||||
use lance_datagen::{BatchCount, RowCount};
|
||||
|
||||
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::permutation::split::SplitSizes};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_permutation_builder() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
|
||||
let db = connect(temp_dir.path().to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let initial_data = lance_datagen::gen_batch()
|
||||
.col("some_value", lance_datagen::array::step::<Int32Type>())
|
||||
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
|
||||
let data_table = db
|
||||
.create_table_streaming("mytbl", initial_data)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let permutation_table = PermutationBuilder::new(data_table.clone())
|
||||
.with_filter("some_value > 57".to_string())
|
||||
.with_split_strategy(
|
||||
SplitStrategy::Random {
|
||||
seed: Some(42),
|
||||
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
|
||||
},
|
||||
None,
|
||||
)
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("permutation_table: {:?}", permutation_table);
|
||||
|
||||
// Potentially brittle seed-dependent values below
|
||||
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
|
||||
assert_eq!(
|
||||
permutation_table
|
||||
.count_rows(Some("split_id = 0".to_string()))
|
||||
.await
|
||||
.unwrap(),
|
||||
47
|
||||
);
|
||||
assert_eq!(
|
||||
permutation_table
|
||||
.count_rows(Some("split_id = 1".to_string()))
|
||||
.await
|
||||
.unwrap(),
|
||||
283
|
||||
);
|
||||
}
|
||||
}
|
||||
546
rust/lancedb/src/dataloader/permutation/reader.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Row ID-based views for LanceDB tables
|
||||
//!
|
||||
//! This module provides functionality for creating views that are based on specific row IDs.
|
||||
//! The `PermutationReader` allows you to create a virtual table that contains only
|
||||
//! the rows from a source table that correspond to row IDs stored in a separate table.
|
||||
|
||||
use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
|
||||
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
|
||||
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
|
||||
use crate::error::Error;
|
||||
use crate::query::{
|
||||
ExecutableQuery, QueryBase, QueryExecutionOptions, QueryFilter, QueryRequest, Select,
|
||||
};
|
||||
use crate::table::{AnyQuery, BaseTable, Filter};
|
||||
use crate::{Result, Table};
|
||||
use arrow::array::AsArray;
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow::datatypes::UInt64Type;
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use arrow_schema::SchemaRef;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::io::RecordBatchStream;
|
||||
use lance_arrow::RecordBatchExt;
|
||||
use lance_core::error::LanceOptionExt;
|
||||
use lance_core::ROW_ID;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Reads a permutation of a source table based on row IDs stored in a separate table
|
||||
#[derive(Clone)]
|
||||
pub struct PermutationReader {
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Option<Arc<dyn BaseTable>>,
|
||||
offset: Option<u64>,
|
||||
limit: Option<u64>,
|
||||
available_rows: u64,
|
||||
split: u64,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PermutationReader {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"PermutationReader(base={}, permutation={}, split={}, offset={:?}, limit={:?})",
|
||||
self.base_table.name(),
|
||||
self.permutation_table
|
||||
.as_ref()
|
||||
.map(|t| t.name())
|
||||
.unwrap_or("--"),
|
||||
self.split,
|
||||
self.offset,
|
||||
self.limit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl PermutationReader {
|
||||
/// Create a new PermutationReader
|
||||
pub async fn inner_new(
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Option<Arc<dyn BaseTable>>,
|
||||
split: u64,
|
||||
) -> Result<Self> {
|
||||
let mut slf = Self {
|
||||
base_table,
|
||||
permutation_table,
|
||||
offset: None,
|
||||
limit: None,
|
||||
available_rows: 0,
|
||||
split,
|
||||
};
|
||||
slf.validate().await?;
|
||||
// Calculate the number of available rows
|
||||
slf.available_rows = slf.verify_limit_offset(None, None).await?;
|
||||
if slf.available_rows == 0 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "No rows found in the permutation table for the given split".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(slf)
|
||||
}
|
||||
|
||||
pub async fn try_from_tables(
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
permutation_table: Arc<dyn BaseTable>,
|
||||
split: u64,
|
||||
) -> Result<Self> {
|
||||
Self::inner_new(base_table, Some(permutation_table), split).await
|
||||
}
|
||||
|
||||
pub async fn identity(base_table: Arc<dyn BaseTable>) -> Self {
|
||||
Self::inner_new(base_table, None, 0).await.unwrap()
|
||||
}
|
||||
|
||||
/// Validates the limit and offset and returns the number of rows that will be read
|
||||
fn validate_limit_offset(
|
||||
limit: Option<u64>,
|
||||
offset: Option<u64>,
|
||||
available_rows: u64,
|
||||
) -> Result<u64> {
|
||||
match (limit, offset) {
|
||||
(Some(limit), Some(offset)) => {
|
||||
if offset + limit > available_rows {
|
||||
Err(Error::InvalidInput {
|
||||
message: "Offset + limit is greater than the number of rows in the permutation table"
|
||||
.to_string(),
|
||||
})
|
||||
} else {
|
||||
Ok(limit)
|
||||
}
|
||||
}
|
||||
(None, Some(offset)) => {
|
||||
if offset > available_rows {
|
||||
Err(Error::InvalidInput {
|
||||
message:
|
||||
"Offset is greater than the number of rows in the permutation table"
|
||||
.to_string(),
|
||||
})
|
||||
} else {
|
||||
Ok(available_rows - offset)
|
||||
}
|
||||
}
|
||||
(Some(limit), None) => {
|
||||
if limit > available_rows {
|
||||
Err(Error::InvalidInput {
|
||||
message:
|
||||
"Limit is greater than the number of rows in the permutation table"
|
||||
.to_string(),
|
||||
})
|
||||
} else {
|
||||
Ok(limit)
|
||||
}
|
||||
}
|
||||
(None, None) => Ok(available_rows),
|
||||
}
|
||||
}
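    // Worked example of the rules above (numbers are illustrative): with 100 available rows,
    // (limit=Some(30), offset=Some(80)) is rejected because 80 + 30 > 100;
    // (limit=None, offset=Some(80)) yields 20 readable rows; (limit=Some(30), offset=None) yields 30.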
|
||||
|
||||
async fn verify_limit_offset(&self, limit: Option<u64>, offset: Option<u64>) -> Result<u64> {
|
||||
let available_rows = if let Some(permutation_table) = &self.permutation_table {
|
||||
permutation_table
|
||||
.count_rows(Some(Filter::Sql(format!(
|
||||
"{} = {}",
|
||||
SPLIT_ID_COLUMN, self.split
|
||||
))))
|
||||
.await? as u64
|
||||
} else {
|
||||
self.base_table.count_rows(None).await? as u64
|
||||
};
|
||||
Self::validate_limit_offset(limit, offset, available_rows)
|
||||
}
|
||||
|
||||
pub async fn with_offset(mut self, offset: u64) -> Result<Self> {
|
||||
let available_rows = self.verify_limit_offset(self.limit, Some(offset)).await?;
|
||||
self.offset = Some(offset);
|
||||
self.available_rows = available_rows;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub async fn with_limit(mut self, limit: u64) -> Result<Self> {
|
||||
let available_rows = self.verify_limit_offset(Some(limit), self.offset).await?;
|
||||
self.available_rows = available_rows;
|
||||
self.limit = Some(limit);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
|
||||
for (expected, idx) in iter.enumerate() {
|
||||
if *idx != expected as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
async fn load_batch(
|
||||
base_table: &Arc<dyn BaseTable>,
|
||||
row_ids: RecordBatch,
|
||||
selection: Select,
|
||||
has_row_id: bool,
|
||||
) -> Result<RecordBatch> {
|
||||
let num_rows = row_ids.num_rows();
|
||||
let row_ids = row_ids
|
||||
.column(0)
|
||||
.as_primitive_opt::<UInt64Type>()
|
||||
.expect_ok()?
|
||||
.values();
|
||||
|
||||
let filter = format!(
|
||||
"_rowid in ({})",
|
||||
row_ids
|
||||
.iter()
|
||||
.map(|o| o.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let base_query = QueryRequest {
|
||||
filter: Some(QueryFilter::Sql(filter)),
|
||||
select: selection,
|
||||
with_row_id: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let data = base_table
|
||||
.query(
|
||||
&AnyQuery::Query(base_query),
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: num_rows as u32,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let schema = data.schema();
|
||||
|
||||
let batches = data.try_collect::<Vec<_>>().await?;
|
||||
|
||||
if batches.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned no batches".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if batches.iter().map(|b| b.num_rows()).sum::<usize>() != num_rows {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Base table returned different number of rows than the number of row IDs"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let batch = if batches.len() == 1 {
|
||||
batches.into_iter().next().unwrap()
|
||||
} else {
|
||||
concat_batches(&schema, &batches)?
|
||||
};
|
||||
|
||||
// There is no guarantee the result order will match the order provided
|
||||
// so we may need to restore the order
|
||||
let actual_row_ids = batch
|
||||
.column_by_name(ROW_ID)
|
||||
.expect_ok()?
|
||||
.as_primitive_opt::<UInt64Type>()
|
||||
.expect_ok()?
|
||||
.values();
|
||||
|
||||
// Map from row id to order in batch, used to restore original ordering
|
||||
let ordering = actual_row_ids
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.map(|(i, o)| (o, i as u64))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let desired_idx_order = row_ids
|
||||
.iter()
|
||||
.map(|o| ordering.get(o).copied().expect_ok().map_err(Error::from))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let ordered_batch = if Self::is_sorted_already(desired_idx_order.iter()) {
|
||||
// Fast path if already sorted, important as data may be large and
|
||||
// re-ordering could be expensive
|
||||
batch
|
||||
} else {
|
||||
let desired_idx_order = UInt64Array::from(desired_idx_order);
|
||||
|
||||
arrow_select::take::take_record_batch(&batch, &desired_idx_order)?
|
||||
};
|
||||
|
||||
if has_row_id {
|
||||
Ok(ordered_batch)
|
||||
} else {
|
||||
// The user didn't ask for the row id; we only needed it to restore ordering, so we drop it now
|
||||
Ok(ordered_batch.drop_column(ROW_ID)?)
|
||||
}
|
||||
}
|
||||
|
||||
async fn row_ids_to_batches(
|
||||
base_table: Arc<dyn BaseTable>,
|
||||
row_ids: DatasetRecordBatchStream,
|
||||
selection: Select,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let has_row_id = Self::has_row_id(&selection)?;
|
||||
let mut stream = row_ids
|
||||
.map_err(Error::from)
|
||||
.try_filter_map(move |batch| {
|
||||
let selection = selection.clone();
|
||||
let base_table = base_table.clone();
|
||||
async move {
|
||||
Self::load_batch(&base_table, batch, selection, has_row_id)
|
||||
.await
|
||||
.map(Some)
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
|
||||
// Need to read out first batch to get schema
|
||||
let Some(first_batch) = stream.try_next().await? else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation was empty".to_string(),
|
||||
});
|
||||
};
|
||||
let schema = first_batch.schema();
|
||||
|
||||
let stream = futures::stream::once(std::future::ready(Ok(first_batch))).chain(stream);
|
||||
|
||||
Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
|
||||
}
|
||||
|
||||
fn has_row_id(selection: &Select) -> Result<bool> {
|
||||
match selection {
|
||||
Select::All => {
|
||||
// _rowid is a system column and is not included in Select::All
|
||||
Ok(false)
|
||||
}
|
||||
Select::Columns(columns) => Ok(columns.contains(&ROW_ID.to_string())),
|
||||
Select::Dynamic(columns) => {
|
||||
for column in columns {
|
||||
if column.0 == ROW_ID {
|
||||
if column.1 == ROW_ID {
|
||||
return Ok(true);
|
||||
} else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"Dynamic column {} cannot be used to select _rowid",
|
||||
column.1
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate(&self) -> Result<()> {
|
||||
if let Some(permutation_table) = &self.permutation_table {
|
||||
let schema = permutation_table.schema().await?;
|
||||
if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named row_id".to_string(),
|
||||
});
|
||||
}
|
||||
if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Permutation table must contain a column named split_id".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
let avail_rows = if let Some(permutation_table) = &self.permutation_table {
|
||||
permutation_table.count_rows(None).await? as u64
|
||||
} else {
|
||||
self.base_table.count_rows(None).await? as u64
|
||||
};
|
||||
Self::validate_limit_offset(self.limit, self.offset, avail_rows)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read(
|
||||
&self,
|
||||
selection: Select,
|
||||
execution_options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
// Note: this relies on the row ids query being returned in a consistent order
|
||||
let row_ids = if let Some(permutation_table) = &self.permutation_table {
|
||||
permutation_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
|
||||
filter: Some(QueryFilter::Sql(format!(
|
||||
"{} = {}",
|
||||
SPLIT_ID_COLUMN, self.split
|
||||
))),
|
||||
offset: self.offset.map(|o| o as usize),
|
||||
limit: self.limit.map(|l| l as usize),
|
||||
..Default::default()
|
||||
}),
|
||||
execution_options,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
self.base_table
|
||||
.query(
|
||||
&AnyQuery::Query(QueryRequest {
|
||||
select: Select::Columns(vec![ROW_ID.to_string()]),
|
||||
offset: self.offset.map(|o| o as usize),
|
||||
limit: self.limit.map(|l| l as usize),
|
||||
..Default::default()
|
||||
}),
|
||||
execution_options,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
|
||||
}
|
||||
|
||||
pub async fn output_schema(&self, selection: Select) -> Result<SchemaRef> {
|
||||
let table = Table::from(self.base_table.clone());
|
||||
table.query().select(selection).output_schema().await
|
||||
}
|
||||
|
||||
pub fn count_rows(&self) -> u64 {
|
||||
self.available_rows
|
||||
}
|
||||
}
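// Illustrative usage sketch (hypothetical helper; the column name and split index are arbitrary).
// Streams one split of a previously built permutation back in permutation order.
async fn example_read_split(base: &Table, permutation: &Table) -> Result<()> {
    let reader = PermutationReader::try_from_tables(
        base.base_table().clone(),
        permutation.base_table().clone(),
        /*split=*/ 0,
    )
    .await?;
    println!("rows in split 0: {}", reader.count_rows());

    let mut stream = reader
        .read(
            Select::Columns(vec!["some_value".to_string()]),
            QueryExecutionOptions::default(),
        )
        .await?;
    while let Some(batch) = stream.try_next().await? {
        // Batches arrive in the order defined by the permutation table.
        let _ = batch.num_rows();
    }
    Ok(())
}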
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow::datatypes::Int32Type;
|
||||
use arrow_array::{ArrowPrimitiveType, RecordBatch, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance_datagen::{BatchCount, RowCount};
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
use crate::{
|
||||
arrow::SendableRecordBatchStream,
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
test_utils::datagen::{virtual_table, LanceDbDatagenExt},
|
||||
Table,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
async fn collect_from_stream<T: ArrowPrimitiveType>(
|
||||
mut stream: SendableRecordBatchStream,
|
||||
column: &str,
|
||||
) -> Vec<T::Native> {
|
||||
let mut row_ids = Vec::new();
|
||||
while let Some(batch) = stream.try_next().await.unwrap() {
|
||||
let col_idx = batch.schema().index_of(column).unwrap();
|
||||
row_ids.extend(batch.column(col_idx).as_primitive::<T>().values().to_vec());
|
||||
}
|
||||
row_ids
|
||||
}
|
||||
|
||||
async fn collect_column<T: ArrowPrimitiveType>(table: &Table, column: &str) -> Vec<T::Native> {
|
||||
collect_from_stream::<T>(
|
||||
table
|
||||
.query()
|
||||
.select(Select::Columns(vec![column.to_string()]))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap(),
|
||||
column,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_permutation_reader() {
|
||||
let base_table = lance_datagen::gen_batch()
|
||||
.col("idx", lance_datagen::array::step::<Int32Type>())
|
||||
.col("other_col", lance_datagen::array::step::<UInt64Type>())
|
||||
.into_mem_table("tbl", RowCount::from(9), BatchCount::from(1))
|
||||
.await;
|
||||
|
||||
let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
|
||||
row_ids.shuffle(&mut rand::rng());
|
||||
// Put the last two rows in split 1
|
||||
let split_ids = UInt64Array::from_iter_values(
|
||||
std::iter::repeat_n(0, row_ids.len() - 2).chain(std::iter::repeat_n(1, 2)),
|
||||
);
|
||||
let permutation_batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("row_id", DataType::UInt64, false),
|
||||
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(UInt64Array::from(row_ids.clone())),
|
||||
Arc::new(split_ids),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let row_ids_table = virtual_table("row_ids", &permutation_batch).await;
|
||||
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Read split 0
|
||||
let mut stream = reader
|
||||
.read(
|
||||
Select::All,
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: 3,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stream.schema(), base_table.schema().await.unwrap());
|
||||
|
||||
let check_batch = async |stream: &mut SendableRecordBatchStream,
|
||||
expected_values: &[u64]| {
|
||||
let batch = stream.try_next().await.unwrap().unwrap();
|
||||
assert_eq!(batch.num_rows(), expected_values.len());
|
||||
assert_eq!(
|
||||
batch.column(0).as_primitive::<Int32Type>().values(),
|
||||
&expected_values
|
||||
.iter()
|
||||
.map(|o| *o as i32)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
assert_eq!(
|
||||
batch.column(1).as_primitive::<UInt64Type>().values(),
|
||||
&expected_values
|
||||
);
|
||||
};
|
||||
|
||||
check_batch(&mut stream, &row_ids[0..3]).await;
|
||||
check_batch(&mut stream, &row_ids[3..6]).await;
|
||||
check_batch(&mut stream, &row_ids[6..7]).await;
|
||||
assert!(stream.try_next().await.unwrap().is_none());
|
||||
|
||||
// Read split 1
|
||||
let reader = PermutationReader::try_from_tables(
|
||||
base_table.base_table().clone(),
|
||||
row_ids_table.base_table().clone(),
|
||||
1,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut stream = reader
|
||||
.read(
|
||||
Select::All,
|
||||
QueryExecutionOptions {
|
||||
max_batch_length: 3,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
check_batch(&mut stream, &row_ids[7..9]).await;
|
||||
assert!(stream.try_next().await.unwrap().is_none());
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,7 @@ use rand::{seq::SliceRandom, Rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
dataloader::util::{non_crypto_rng, TemporaryDirectory},
|
||||
dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
@@ -13,13 +13,13 @@ use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use datafusion_common::hash_utils::create_hashes;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::arrow::SchemaExt;
|
||||
use lance_arrow::SchemaExt;
|
||||
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
dataloader::{
|
||||
shuffle::{Shuffler, ShufflerConfig},
|
||||
util::TemporaryDirectory,
|
||||
permutation::shuffle::{Shuffler, ShufflerConfig},
|
||||
permutation::util::TemporaryDirectory,
|
||||
},
|
||||
query::{Query, QueryBase, Select},
|
||||
Error, Result,
|
||||
@@ -10,7 +10,7 @@ pub mod sentence_transformers;
|
||||
#[cfg(feature = "bedrock")]
|
||||
pub mod bedrock;
|
||||
|
||||
use lance::arrow::RecordBatchExt;
|
||||
use lance_arrow::RecordBatchExt;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::{HashMap, HashSet},
|
||||
|
||||
@@ -6,6 +6,8 @@ use std::sync::PoisonError;
|
||||
use arrow_schema::ArrowError;
|
||||
use snafu::Snafu;
|
||||
|
||||
type BoxError = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub(crate)))]
|
||||
pub enum Error {
|
||||
@@ -14,7 +16,7 @@ pub enum Error {
|
||||
#[snafu(display("Invalid input, {message}"))]
|
||||
InvalidInput { message: String },
|
||||
#[snafu(display("Table '{name}' was not found"))]
|
||||
TableNotFound { name: String },
|
||||
TableNotFound { name: String, source: BoxError },
|
||||
#[snafu(display("Database '{name}' was not found"))]
|
||||
DatabaseNotFound { name: String },
|
||||
#[snafu(display("Database '{name}' already exists."))]
|
||||
|
||||
@@ -207,7 +207,8 @@ pub mod query;
|
||||
pub mod remote;
|
||||
pub mod rerankers;
|
||||
pub mod table;
|
||||
pub mod test_connection;
|
||||
#[cfg(test)]
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
@@ -6,15 +6,13 @@ use std::{future::Future, time::Duration};
|
||||
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
|
||||
use arrow_schema::DataType;
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
use datafusion_expr::Expr;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use futures::{stream, try_join, FutureExt, TryStreamExt};
|
||||
use futures::{stream, try_join, FutureExt, TryFutureExt, TryStreamExt};
|
||||
use half::f16;
|
||||
use lance::{
|
||||
arrow::RecordBatchExt,
|
||||
dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
|
||||
};
|
||||
use lance::dataset::{scanner::DatasetRecordBatchStream, ROW_ID};
|
||||
use lance_arrow::RecordBatchExt;
|
||||
use lance_datafusion::exec::execute_plan;
|
||||
use lance_index::scalar::inverted::SCORE_COL;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
@@ -36,7 +34,7 @@ pub(crate) const DEFAULT_TOP_K: usize = 10;
|
||||
/// Which columns should be retrieved from the database
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Select {
|
||||
/// Select all columns
|
||||
/// Select all non-system columns
|
||||
///
|
||||
/// Warning: This will always be slower than selecting only the columns you need.
|
||||
All,
|
||||
@@ -582,16 +580,40 @@ pub trait ExecutableQuery {
|
||||
options: QueryExecutionOptions,
|
||||
) -> impl Future<Output = Result<SendableRecordBatchStream>> + Send;
|
||||
|
||||
/// Explain the plan for a query
|
||||
///
|
||||
/// This will create a string representation of the plan that will be used to
|
||||
/// execute the query. This will not execute the query.
|
||||
///
|
||||
/// This function can be used to get an understanding of what work will be done by the query
|
||||
/// and is useful for debugging query performance.
|
||||
fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
|
||||
|
||||
/// Execute the query and display the runtime metrics
|
||||
///
|
||||
/// This shows the same plan as [`ExecutableQuery::explain_plan`] but includes runtime metrics.
|
||||
///
|
||||
/// This function will actually execute the query in order to get the runtime metrics.
|
||||
fn analyze_plan(&self) -> impl Future<Output = Result<String>> + Send {
|
||||
self.analyze_plan_with_options(QueryExecutionOptions::default())
|
||||
}
|
||||
|
||||
/// Execute the query and display the runtime metrics
|
||||
///
|
||||
/// This is the same as [`ExecutableQuery::analyze_plan`] but allows for specifying the execution options.
|
||||
fn analyze_plan_with_options(
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> impl Future<Output = Result<String>> + Send;
|
||||
|
||||
/// Return the output schema for data returned by the query without actually executing the query
|
||||
///
|
||||
/// This can be useful when the selection for a query is built dynamically as it is not always
|
||||
/// obvious what the output schema will be.
|
||||
fn output_schema(&self) -> impl Future<Output = Result<SchemaRef>> + Send {
|
||||
self.create_plan(QueryExecutionOptions::default())
|
||||
.and_then(|plan| std::future::ready(Ok(plan.schema())))
|
||||
}
|
||||
}
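// Illustrative usage sketch (hypothetical helper). `explain_plan` and `output_schema` only
// inspect the plan; `analyze_plan` actually runs the query to gather runtime metrics.
async fn example_inspect_query(table: &crate::Table) -> Result<()> {
    let query = table.query().limit(10);
    let plan = query.explain_plan(false).await?;
    let schema = query.output_schema().await?;
    let metrics = query.analyze_plan().await?;
    println!("plan:\n{plan}\nschema: {schema:?}\nmetrics:\n{metrics}");
    Ok(())
}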
|
||||
|
||||
/// A query filter that can be applied to a query
|
||||
@@ -645,6 +667,12 @@ pub struct QueryRequest {
|
||||
|
||||
/// Configure how query results are normalized when doing hybrid search
|
||||
pub norm: Option<NormalizeMethod>,
|
||||
|
||||
/// If set to true, disables automatic projection of scoring columns (_score, _distance).
|
||||
/// When disabled, these columns are only included if explicitly requested in the projection.
|
||||
///
|
||||
/// By default, this is false (scoring columns are auto-projected for backward compatibility).
|
||||
pub disable_scoring_autoprojection: bool,
|
||||
}
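// Sketch: opting out of automatic `_score`/`_distance` projection while keeping every other
// field at its default value.
fn example_request_without_score_columns() -> QueryRequest {
    QueryRequest {
        disable_scoring_autoprojection: true,
        ..Default::default()
    }
}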
|
||||
|
||||
impl Default for QueryRequest {
|
||||
@@ -660,6 +688,7 @@ impl Default for QueryRequest {
|
||||
prefilter: true,
|
||||
reranker: None,
|
||||
norm: None,
|
||||
disable_scoring_autoprojection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1505,6 +1534,16 @@ mod tests {
|
||||
.query()
|
||||
.limit(10)
|
||||
.select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")]));
|
||||
|
||||
let schema = query.output_schema().await.unwrap();
|
||||
assert_eq!(
|
||||
schema,
|
||||
Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("id2", DataType::Int32, true),
|
||||
ArrowField::new("id", DataType::Int32, true),
|
||||
]))
|
||||
);
|
||||
|
||||
let result = query.execute().await;
|
||||
let mut batches = result
|
||||
.expect("should have result")
|
||||
|
||||
@@ -515,11 +515,8 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/describe/", identifier));
|
||||
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
|
||||
if rsp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound {
|
||||
name: identifier.clone(),
|
||||
});
|
||||
}
|
||||
let rsp =
|
||||
RemoteTable::<S>::handle_table_not_found(&request.name, rsp, &request_id).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let version = parse_server_version(&request_id, &rsp)?;
|
||||
let table_identifier = build_table_identifier(
|
||||
|
||||
@@ -336,16 +336,33 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub(super) async fn handle_table_not_found(
|
||||
table_name: &str,
|
||||
response: reqwest::Response,
|
||||
request_id: &str,
|
||||
) -> Result<reqwest::Response> {
|
||||
let status = response.status();
|
||||
if status == StatusCode::NOT_FOUND {
|
||||
let body = response.text().await.ok().unwrap_or_default();
|
||||
let request_error = Error::Http {
|
||||
source: body.into(),
|
||||
request_id: request_id.into(),
|
||||
status_code: Some(status),
|
||||
};
|
||||
return Err(Error::TableNotFound {
|
||||
name: table_name.to_string(),
|
||||
source: Box::new(request_error),
|
||||
});
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn check_table_response(
|
||||
&self,
|
||||
request_id: &str,
|
||||
response: reqwest::Response,
|
||||
) -> Result<reqwest::Response> {
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Err(Error::TableNotFound {
|
||||
name: self.identifier.clone(),
|
||||
});
|
||||
}
|
||||
let response = Self::handle_table_not_found(&self.name, response, request_id).await?;
|
||||
|
||||
self.client.check_response(request_id, response).await
|
||||
}
|
||||
@@ -681,8 +698,9 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
.map_err(|e| match e {
|
||||
// try to map the error to a more user-friendly error telling them
|
||||
// specifically that the version does not exist
|
||||
Error::TableNotFound { name } => Error::TableNotFound {
|
||||
Error::TableNotFound { name, source } => Error::TableNotFound {
|
||||
name: format!("{} (version: {})", name, version),
|
||||
source,
|
||||
},
|
||||
e => e,
|
||||
})?;
|
||||
@@ -1427,6 +1445,10 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
"NOT_SUPPORTED"
|
||||
}
|
||||
|
||||
async fn storage_options(&self) -> Option<HashMap<String, String>> {
|
||||
None
|
||||
}
|
||||
|
||||
async fn stats(&self) -> Result<TableStatistics> {
|
||||
let request = self
|
||||
.client
|
||||
@@ -1571,7 +1593,11 @@ mod tests {
|
||||
for result in results {
|
||||
let result = result.await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result, Err(Error::TableNotFound { name }) if name == "my_table"));
|
||||
assert!(
|
||||
matches!(&result, &Err(Error::TableNotFound { ref name, .. }) if name == "my_table")
|
||||
);
|
||||
let full_error_report = snafu::Report::from_error(result.unwrap_err()).to_string();
|
||||
assert!(full_error_report.contains("table my_table not found"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2880,7 +2906,7 @@ mod tests {
|
||||
let res = table.checkout(43).await;
|
||||
println!("{:?}", res);
|
||||
assert!(
|
||||
matches!(res, Err(Error::TableNotFound { name }) if name == "my_table (version: 43)")
|
||||
matches!(res, Err(Error::TableNotFound { name, .. }) if name == "my_table (version: 43)")
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -511,6 +511,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
/// Get the namespace of the table.
|
||||
fn namespace(&self) -> &[String];
|
||||
/// Get the id of the table
|
||||
///
|
||||
/// This is the namespace of the table concatenated with the name
|
||||
/// separated by a dot (".")
|
||||
fn id(&self) -> &str;
|
||||
/// Get the arrow [Schema] of the table.
|
||||
async fn schema(&self) -> Result<SchemaRef>;
|
||||
@@ -598,6 +601,8 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
async fn table_definition(&self) -> Result<TableDefinition>;
|
||||
/// Get the table URI
|
||||
fn dataset_uri(&self) -> &str;
|
||||
/// Get the storage options used when opening this table, if any.
|
||||
async fn storage_options(&self) -> Option<HashMap<String, String>>;
|
||||
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
|
||||
/// are not fully indexed within the timeout.
|
||||
async fn wait_for_index(
|
||||
@@ -615,7 +620,7 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Table {
|
||||
inner: Arc<dyn BaseTable>,
|
||||
database: Arc<dyn Database>,
|
||||
database: Option<Arc<dyn Database>>,
|
||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||
}
|
||||
|
||||
@@ -639,7 +644,7 @@ mod test_utils {
|
||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
// Registry is unused.
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
@@ -661,7 +666,7 @@ mod test_utils {
|
||||
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
// Registry is unused.
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
@@ -675,11 +680,21 @@ impl std::fmt::Display for Table {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<dyn BaseTable>> for Table {
|
||||
fn from(inner: Arc<dyn BaseTable>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database: None,
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Table {
|
||||
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
@@ -689,7 +704,7 @@ impl Table {
|
||||
}
|
||||
|
||||
pub fn database(&self) -> &Arc<dyn Database> {
|
||||
&self.database
|
||||
self.database.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
|
||||
@@ -703,7 +718,7 @@ impl Table {
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
database,
|
||||
database: Some(database),
|
||||
embedding_registry,
|
||||
}
|
||||
}
|
||||
@@ -1290,6 +1305,13 @@ impl Table {
|
||||
self.inner.dataset_uri()
|
||||
}
|
||||
|
||||
/// Get the storage options used when opening this table, if any.
|
||||
///
|
||||
/// Warning: This is an internal API and the return value is subject to change.
|
||||
pub async fn storage_options(&self) -> Option<HashMap<String, String>> {
|
||||
self.inner.storage_options().await
|
||||
}
|
||||
|
||||
/// Get statistics about an index.
|
||||
/// Returns None if the index does not exist.
|
||||
pub async fn index_stats(
|
||||
@@ -1522,6 +1544,7 @@ impl NativeTable {
|
||||
.map_err(|e| match e {
|
||||
lance::Error::DatasetNotFound { .. } => Error::TableNotFound {
|
||||
name: name.to_string(),
|
||||
source: Box::new(e),
|
||||
},
|
||||
source => Error::Lance { source },
|
||||
})?;
|
||||
@@ -1542,6 +1565,7 @@ impl NativeTable {
|
||||
.file_stem()
|
||||
.ok_or(Error::TableNotFound {
|
||||
name: uri.to_string(),
|
||||
source: format!("Could not extract table name from URI: '{}'", uri).into(),
|
||||
})?
|
||||
.to_str()
|
||||
.ok_or(Error::InvalidTableName {
|
||||
@@ -2379,6 +2403,10 @@ impl BaseTable for NativeTable {
|
||||
scanner.distance_metric(distance_type.into());
|
||||
}
|
||||
|
||||
if query.base.disable_scoring_autoprojection {
|
||||
scanner.disable_scoring_autoprojection();
|
||||
}
|
||||
|
||||
Ok(scanner.create_plan().await?)
|
||||
}
|
||||
|
||||
@@ -2614,6 +2642,14 @@ impl BaseTable for NativeTable {
|
||||
self.uri.as_str()
|
||||
}
|
||||
|
||||
async fn storage_options(&self) -> Option<HashMap<String, String>> {
|
||||
self.dataset
|
||||
.get()
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|dataset| dataset.storage_options().cloned())
|
||||
}
|
||||
|
||||
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
|
||||
let stats = match self
|
||||
.dataset
|
||||
@@ -2623,7 +2659,7 @@ impl BaseTable for NativeTable {
|
||||
.await
|
||||
{
|
||||
Ok(stats) => stats,
|
||||
Err(lance::error::Error::IndexNotFound { .. }) => return Ok(None),
|
||||
Err(lance_core::Error::IndexNotFound { .. }) => return Ok(None),
|
||||
Err(e) => return Err(Error::from(e)),
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
|
||||
|
||||
pub mod udtf;
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
@@ -21,6 +24,8 @@ use crate::{
|
||||
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
|
||||
Result,
|
||||
};
|
||||
use arrow_schema::{DataType, Field};
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
/// Datafusion attempts to maintain batch metadata
|
||||
///
|
||||
@@ -135,19 +140,38 @@ impl ExecutionPlan for MetadataEraserExec {
|
||||
pub struct BaseTableAdapter {
|
||||
table: Arc<dyn BaseTable>,
|
||||
schema: Arc<ArrowSchema>,
|
||||
fts_query: Option<FullTextSearchQuery>,
|
||||
}
|
||||
|
||||
impl BaseTableAdapter {
|
||||
pub async fn try_new(table: Arc<dyn BaseTable>) -> Result<Self> {
|
||||
let schema = Arc::new(
|
||||
table
|
||||
.schema()
|
||||
.await?
|
||||
.as_ref()
|
||||
.clone()
|
||||
.with_metadata(HashMap::default()),
|
||||
);
|
||||
Ok(Self { table, schema })
|
||||
let schema = table
|
||||
.schema()
|
||||
.await?
|
||||
.as_ref()
|
||||
.clone()
|
||||
.with_metadata(HashMap::default());
|
||||
|
||||
Ok(Self {
|
||||
table,
|
||||
schema: Arc::new(schema),
|
||||
fts_query: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new adapter with an FTS query applied.
|
||||
pub fn with_fts_query(&self, fts_query: FullTextSearchQuery) -> Self {
|
||||
// Add _score column to the schema
|
||||
let score_field = Field::new("_score", DataType::Float32, true);
|
||||
let mut fields = self.schema.fields().to_vec();
|
||||
fields.push(Arc::new(score_field));
|
||||
let schema = Arc::new(ArrowSchema::new(fields));
|
||||
|
||||
Self {
|
||||
table: self.table.clone(),
|
||||
schema,
|
||||
fts_query: Some(fts_query),
|
||||
}
|
||||
}
|
||||
}
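// Illustrative sketch (hypothetical helper): exposing a LanceDB table to DataFusion as a table
// provider with a full-text search query attached. The resulting provider reports an extra
// nullable `_score` column in its schema.
async fn example_fts_provider(
    table: Arc<dyn BaseTable>,
    fts: FullTextSearchQuery,
) -> Result<Arc<BaseTableAdapter>> {
    let adapter = BaseTableAdapter::try_new(table).await?.with_fts_query(fts);
    Ok(Arc::new(adapter))
}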
|
||||
|
||||
@@ -172,7 +196,15 @@ impl TableProvider for BaseTableAdapter {
|
||||
filters: &[Expr],
|
||||
limit: Option<usize>,
|
||||
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
|
||||
let mut query = QueryRequest::default();
|
||||
// For FTS queries, disable auto-projection of _score to match DataFusion expectations
|
||||
let disable_scoring = self.fts_query.is_some() && projection.is_some();
|
||||
|
||||
let mut query = QueryRequest {
|
||||
full_text_search: self.fts_query.clone(),
|
||||
disable_scoring_autoprojection: disable_scoring,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(projection) = projection {
|
||||
let field_names = projection
|
||||
.iter()
|
||||
|
||||
6
rust/lancedb/src/table/datafusion/udtf.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! User-Defined Table Functions (UDTFs) for DataFusion integration
|
||||
|
||||
pub mod fts;
|
||||
2028
rust/lancedb/src/table/datafusion/udtf/fts.rs
Normal file
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff