Mirror of https://github.com/lancedb/lancedb.git, synced 2025-12-26 06:39:57 +00:00
Compare commits
48 Commits
python-v0. ... python-v0.
| SHA1 |
|---|
| 4f7b24d1a9 |
| f9540724b7 |
| aeac9c7644 |
| 6ddd271627 |
| f0d7520bdf |
| 7ef8bafd51 |
| aed4a7c98e |
| 273ba18426 |
| 8b94308cf2 |
| 0b7b27481e |
| e1f9b011f8 |
| d664b8739f |
| 20bec61ecb |
| 45255be42c |
| 93c2cf2f59 |
| 9d29c83f81 |
| 2a6143b5bd |
| b2242886e0 |
| 199904ab35 |
| 1fa888615f |
| 40967f3baa |
| 0bfc7de32c |
| d43880a585 |
| 59a886958b |
| c36f6746d1 |
| 25ce6d311f |
| 92a4e46f9f |
| 845641c480 |
| d96404c635 |
| 02d31ee412 |
| 308623577d |
| 8ee3ae378f |
| 3372a2aae0 |
| 4cfcd95320 |
| a70ff04bc9 |
| a9daa18be9 |
| 3f2e3986e9 |
| bf55feb9b6 |
| 8f8e06a2da |
| 03eab0f091 |
| 143184c0ae |
| dadb042978 |
| 5a19cf15a6 |
| 3dcec724b7 |
| 86a6bb9fcb |
| b59d1007d3 |
| 56a16b1728 |
| b7afed9beb |
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.22.2-beta.2"
+current_version = "0.22.3-beta.5"
 parse = """(?x)
 (?P<major>0|[1-9]\\d*)\\.
 (?P<minor>0|[1-9]\\d*)\\.
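The `parse` pattern is truncated by the diff context above. As a rough illustration, the sketch below completes the pattern in the same style; the patch and pre-release groups are an assumption, not copied from the repository, and simply show how a version string like the new `0.22.3-beta.5` decomposes:

```python
import re

# Hypothetical completion of the bump-my-version "parse" pattern shown above;
# only the <major> and <minor> groups appear in the diff, the rest is assumed.
PARSE = re.compile(
    r"""(?x)
    (?P<major>0|[1-9]\d*)\.
    (?P<minor>0|[1-9]\d*)\.
    (?P<patch>0|[1-9]\d*)
    (?:-(?P<pre_label>[a-zA-Z-]+)\.(?P<pre_number>0|[1-9]\d*))?
    """
)

print(PARSE.match("0.22.3-beta.5").groupdict())
# {'major': '0', 'minor': '22', 'patch': '3', 'pre_label': 'beta', 'pre_number': '5'}
```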
107 .github/workflows/codex-update-lance-dependency.yml vendored Normal file
@@ -0,0 +1,107 @@
name: Codex Update Lance Dependency

on:
  workflow_call:
    inputs:
      tag:
        description: "Tag name from Lance"
        required: true
        type: string
  workflow_dispatch:
    inputs:
      tag:
        description: "Tag name from Lance"
        required: true
        type: string

permissions:
  contents: write
  pull-requests: write
  actions: read

jobs:
  update:
    runs-on: ubuntu-latest
    steps:
      - name: Show inputs
        run: |
          echo "tag = ${{ inputs.tag }}"

      - name: Checkout Repo LanceDB
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: true

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: 20

      - name: Install Codex CLI
        run: npm install -g @openai/codex

      - name: Install Rust toolchain
        uses: dtolnay/rust-toolchain@stable
        with:
          toolchain: stable
          components: clippy, rustfmt

      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y protobuf-compiler libssl-dev

      - name: Install cargo-info
        run: cargo install cargo-info

      - name: Install Python dependencies
        run: python3 -m pip install --upgrade pip packaging

      - name: Configure git user
        run: |
          git config user.name "lancedb automation"
          git config user.email "robot@lancedb.com"

      - name: Configure Codex authentication
        env:
          CODEX_TOKEN_B64: ${{ secrets.CODEX_TOKEN }}
        run: |
          if [ -z "${CODEX_TOKEN_B64}" ]; then
            echo "Repository secret CODEX_TOKEN is not defined; skipping Codex execution."
            exit 1
          fi
          mkdir -p ~/.codex
          echo "${CODEX_TOKEN_B64}" | base64 --decode > ~/.codex/auth.json

      - name: Run Codex to update Lance dependency
        env:
          TAG: ${{ inputs.tag }}
          GITHUB_TOKEN: ${{ secrets.ROBOT_TOKEN }}
          GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
        run: |
          set -euo pipefail
          VERSION="${TAG#refs/tags/}"
          VERSION="${VERSION#v}"
          BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
          cat <<EOF >/tmp/codex-prompt.txt
          You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.

          Follow these steps exactly:
          1. Use script "ci/set_lance_version.py" to update Lance dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
          2. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
          3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
          4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
          5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
          6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
          7. Push the branch to origin. If the branch already exists, force-push your changes.
          8. The env var "GH_TOKEN" is available; use "gh" tools for GitHub-related operations like creating the pull request.
          9. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}).
          10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.

          Constraints:
          - Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.
          - Do not merge the PR.
          - If any command fails, diagnose and fix the issue instead of aborting.
          EOF
          codex --config shell_environment_policy.ignore_default_excludes=true exec --dangerously-bypass-approvals-and-sandbox "$(cat /tmp/codex-prompt.txt)"
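The last step builds a branch name from the triggering tag with three shell parameter expansions. A minimal Python sketch of the same transformation (illustration only, not part of the repository):

```python
import re

def branch_name_for(tag: str) -> str:
    """Mirror the shell logic: drop "refs/tags/" and a leading "v",
    then replace every non-alphanumeric character with "-"."""
    version = tag.removeprefix("refs/tags/").removeprefix("v")
    return "codex/update-lance-" + re.sub(r"[^a-zA-Z0-9]", "-", version)

print(branch_name_for("refs/tags/v0.39.0"))   # codex/update-lance-0-39-0
print(branch_name_for("v0.39.1-beta.2"))      # codex/update-lance-0-39-1-beta-2
```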
101 AGENTS.md Normal file
@@ -0,0 +1,101 @@
LanceDB is a database designed for retrieval, including vector, full-text, and hybrid search.
It is a wrapper around Lance. There are two backends: local (in-process like SQLite) and
remote (against LanceDB Cloud).

The core of LanceDB is written in Rust. There are bindings in Python, Typescript, and Java.

Project layout:

* `rust/lancedb`: The LanceDB core Rust implementation.
* `python`: The Python bindings, using PyO3.
* `nodejs`: The Typescript bindings, using napi-rs
* `java`: The Java bindings

Common commands:

* Check for compiler errors: `cargo check --quiet --features remote --tests --examples`
* Run tests: `cargo test --quiet --features remote --tests`
* Run specific test: `cargo test --quiet --features remote -p <package_name> --test <test_name>`
* Lint: `cargo clippy --quiet --features remote --tests --examples`
* Format: `cargo fmt --all`

Before committing changes, run formatting.

## Coding tips

* When writing Rust doctests for things that require a connection or table reference,
  write them as a function instead of a fully executable test. This allows type checking
  to run but avoids needing a full test environment. For example:

```rust
/// ```
/// use lance_index::scalar::FullTextSearchQuery;
/// use lancedb::query::{QueryBase, ExecutableQuery};
///
/// # use lancedb::Table;
/// # async fn query(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
/// let results = table.query()
///     .full_text_search(FullTextSearchQuery::new("hello world".into()))
///     .execute()
///     .await?;
/// # Ok(())
/// # }
/// ```
```

## Example plan: adding a new method on Table

Adding a new method involves first adding it to the Rust core, then exposing it
in the Python and TypeScript bindings. There are both local and remote tables.
Remote tables are implemented via a HTTP API and require the `remote` cargo
feature flag to be enabled. Python has both sync and async methods.

Rust core changes:

1. Add method on `Table` struct in `rust/lancedb/src/table.rs` (calls `BaseTable` trait).
2. Add method to `BaseTable` trait in `rust/lancedb/src/table.rs`.
3. Implement new trait method on `NativeTable` in `rust/lancedb/src/table.rs`.
   * Test with unit test in `rust/lancedb/src/table.rs`.
4. Implement new trait method on `RemoteTable` in `rust/lancedb/src/remote/table.rs`.
   * Test with unit test in `rust/lancedb/src/remote/table.rs` against mocked endpoint.

Python bindings changes:

1. Add PyO3 method binding in `python/src/table.rs`. Run `make develop` to compile bindings.
2. Add types for PyO3 method in `python/python/lancedb/_lancedb.pyi`.
3. Add method to `AsyncTable` class in `python/python/lancedb/table.py`.
4. Add abstract method to `Table` abstract base class in `python/python/lancedb/table.py`.
5. Add concrete sync method to `LanceTable` class in `python/python/lancedb/table.py`.
   * Should use `LOOP.run()` to call the corresponding `AsyncTable` method.
6. Add concrete sync method to `RemoteTable` class in `python/python/lancedb/remote/table.py`.
7. Add unit test in `python/tests/test_table.py`.

TypeScript bindings changes:

1. Add napi-rs method binding on `Table` in `nodejs/src/table.rs`.
2. Run `npm run build` to generate TypeScript definitions.
3. Add typescript method on abstract class `Table` in `nodejs/src/table.ts`.
4. Add concrete method on `LocalTable` class in `nodejs/src/native_table.ts`.
   * Note: despite the name, this class is also used for remote tables.
5. Add test in `nodejs/__test__/table.test.ts`.
6. Run `npm run docs` to generate TypeScript documentation.

## Review Guidelines

Please consider the following when reviewing code contributions.

### Rust API design

* Design public APIs so they can be evolved easily in the future without breaking
  changes. Often this means using builder patterns or options structs instead of
  long argument lists.
* For public APIs, prefer inputs that use `Into<T>` or `AsRef<T>` traits to allow
  more flexible inputs. For example, use `name: Into<String>` instead of `name: String`,
  so we don't have to write `func("my_string".to_string())`.

### Testing

* Ensure all new public APIs have documentation and examples.
* Ensure that all bugfixes and features have corresponding tests. **We do not merge
  code without tests.**

### Documentation

* New features must include updates to the rust documentation comments. Link to
  relevant structs and methods to increase the value of documentation.
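Step 5 of the Python bindings plan above wraps each async method with `LOOP.run()`. A minimal sketch of that sync-over-async pattern, assuming a background event loop helper (the class and method names here are illustrative, not the repository's actual implementation):

```python
import asyncio
from threading import Thread

class BackgroundLoop:
    """Stand-in for the LOOP helper: an event loop on a daemon thread,
    so synchronous table methods can delegate to the async ones."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        Thread(target=self._loop.run_forever, daemon=True).start()

    def run(self, coro):
        # Block the calling (sync) thread until the coroutine finishes.
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

LOOP = BackgroundLoop()

class AsyncTable:
    async def count_rows(self) -> int:
        return 42  # placeholder for the real async implementation

class LanceTable:
    def __init__(self, async_table: AsyncTable) -> None:
        self._async = async_table

    def count_rows(self) -> int:
        # The sync method simply forwards to its AsyncTable counterpart.
        return LOOP.run(self._async.count_rows())

print(LanceTable(AsyncTable()).count_rows())  # 42
```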
80 CLAUDE.md
@@ -1,80 +0,0 @@
(CLAUDE.md is deleted; its 80 removed lines are identical to the first 80 lines of the new AGENTS.md above, i.e. everything before the "Review Guidelines" section.)
225 Cargo.lock generated
@@ -72,12 +72,6 @@ version = "0.2.21"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"

-[[package]]
-name = "android-tzdata"
-version = "0.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
-
 [[package]]
 name = "android_system_properties"
 version = "0.1.5"
@@ -1474,17 +1468,16 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"

 [[package]]
 name = "chrono"
-version = "0.4.41"
+version = "0.4.42"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
+checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
 dependencies = [
- "android-tzdata",
  "iana-time-zone",
  "js-sys",
  "num-traits",
  "serde",
  "wasm-bindgen",
- "windows-link 0.1.3",
+ "windows-link 0.2.1",
 ]

 [[package]]
@@ -1573,7 +1566,7 @@ checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56"
 dependencies = [
  "crossterm 0.27.0",
  "crossterm 0.28.1",
- "strum 0.26.3",
+ "strum",
  "strum_macros 0.26.4",
  "unicode-width",
 ]
@@ -2940,18 +2933,6 @@ version = "0.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55"

-[[package]]
-name = "fastbloom"
-version = "0.14.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "18c1ddb9231d8554c2d6bdf4cfaabf0c59251658c68b6c95cd52dd0c513a912a"
-dependencies = [
- "getrandom 0.3.3",
- "libm",
- "rand 0.9.2",
- "siphasher",
-]
-
 [[package]]
 name = "fastdivide"
 version = "0.4.2"
@@ -3051,9 +3032,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"

 [[package]]
 name = "fsst"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "480fc4f47567da549ab44bb2f37f6db1570c9eff7200e50334b69fa1daa74339"
+checksum = "1d2475ce218217196b161b025598f77e2b405d5e729f7c37bfff145f5df00a41"
 dependencies = [
  "arrow-array",
  "rand 0.9.2",
@@ -3850,7 +3831,7 @@ dependencies = [
  "js-sys",
  "log",
  "wasm-bindgen",
- "windows-core 0.62.2",
+ "windows-core 0.61.2",
 ]

 [[package]]
@@ -4237,9 +4218,9 @@ dependencies = [

 [[package]]
 name = "lance"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e2d2472f58d01894bc5f0a9f9d28dfca4649c9e28faf467c47e87f788ef322b"
+checksum = "a2f0ca022d0424d991933a62d2898864cf5621873962bd84e65e7d1f023f9c36"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4278,6 +4259,7 @@ dependencies = [
  "lance-index",
  "lance-io",
  "lance-linalg",
+ "lance-namespace",
  "lance-table",
  "log",
  "moka",
@@ -4288,6 +4270,7 @@ dependencies = [
  "prost-types",
  "rand 0.9.2",
  "roaring",
+ "semver",
  "serde",
  "serde_json",
  "snafu",
@@ -4301,9 +4284,9 @@ dependencies = [

 [[package]]
 name = "lance-arrow"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a2abba8770c4217fbdc8b517cdfb7183639b02dc5c2bcad1e7c69ffdcf4fbe1a"
+checksum = "7552f8d528775bf0ab21e1f75dcb70bdb2a828eeae58024a803b5a4655fd9a11"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -4321,9 +4304,9 @@ dependencies = [

 [[package]]
 name = "lance-bitpacking"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "efb7af69bff8d8499999684f961b0a4dc6e159065c773041545d19bc158f0814"
+checksum = "a2ea14583cc6fa0bb190bcc2d3bc364b0aa545b345702976025f810e4740e8ce"
 dependencies = [
  "arrayref",
  "paste",
@@ -4332,9 +4315,9 @@ dependencies = [

 [[package]]
 name = "lance-core"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "356a5df5f9cd7cb4aedaf78a4e346190ae50ba574b828316caed7d1df3b6dcd8"
+checksum = "69c752dedd207384892006c40930f898d6634e05e3d489e89763abfe4b9307e7"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -4370,9 +4353,9 @@ dependencies = [

 [[package]]
 name = "lance-datafusion"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b8e8ec07021bdaba6a441563d8fbcb0431350aae6842910ae3622557765f218f"
+checksum = "21e1e98ca6e5cd337bdda2d9fb66063f295c0c2852d2bc6831366fea833ee608"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4381,6 +4364,7 @@ dependencies = [
  "arrow-schema",
  "arrow-select",
  "async-trait",
+ "chrono",
  "datafusion",
  "datafusion-common",
  "datafusion-functions",
@@ -4400,9 +4384,9 @@ dependencies = [

 [[package]]
 name = "lance-datagen"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d4fe98730cd5297dc68b22f6ad7e1e27cf34e2db05586b64d3540ca74a519a61"
+checksum = "483c643fc2806ed1a2766edf4d180511bbd1d549bcc60373e33f4785c6185891"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4419,9 +4403,9 @@ dependencies = [

 [[package]]
 name = "lance-encoding"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef073d419cc00ef41dd95cb25203b333118b224151ae397145530b1d559769c9"
+checksum = "a199d1fa3487529c5ffc433fbd1721231330b9350c2ff9b0c7b7dbdb98f0806a"
 dependencies = [
  "arrow-arith",
  "arrow-array",
@@ -4449,7 +4433,7 @@ dependencies = [
  "prost-types",
  "rand 0.9.2",
  "snafu",
- "strum 0.25.0",
+ "strum",
  "tokio",
  "tracing",
  "xxhash-rust",
@@ -4458,9 +4442,9 @@ dependencies = [

 [[package]]
 name = "lance-file"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e34aba3a41f119188da997730560e4a6915ee5a38b672bbf721fdc99121aa1e"
+checksum = "b57def2279465232cf5a8cd996300c632442e368745768bbed661c7f0a35334b"
 dependencies = [
  "arrow-arith",
  "arrow-array",
@@ -4485,7 +4469,6 @@ dependencies = [
  "prost",
  "prost-build",
  "prost-types",
- "roaring",
  "snafu",
  "tokio",
  "tracing",
@@ -4493,9 +4476,9 @@ dependencies = [

 [[package]]
 name = "lance-index"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c5f480f801c8efb41a6dedc48a5cacff6044a10f82c6f9764b8dac7194a7754e"
+checksum = "a75938c61e986aef8c615dc44c92e4c19e393160a59e2b57402ccfe08c5e63af"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4517,7 +4500,6 @@ dependencies = [
  "datafusion-sql",
  "deepsize",
  "dirs",
- "fastbloom",
  "fst",
  "futures",
  "half",
@@ -4557,9 +4539,9 @@ dependencies = [

 [[package]]
 name = "lance-io"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0708125c74965b2b7e5e0c4fe2d8e6bd8346a7031484f8844cf06c08bfa29a72"
+checksum = "fa6c3b5b28570d6c951206c5b043f1b35c936928af14fca6f2ac25b0097e4c32"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4580,6 +4562,7 @@ dependencies = [
  "futures",
  "lance-arrow",
  "lance-core",
+ "lance-namespace",
  "log",
  "object_store",
  "object_store_opendal",
@@ -4598,44 +4581,55 @@ dependencies = [

 [[package]]
 name = "lance-linalg"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "da9d1c22deed92420a1869e4b89188ccecc7e1aee2ea4e5bca92eae861511d60"
+checksum = "b3cbc7e85a89ff9cb3a4627559dea3fd1c1fb16c0d8bc46ede75eefef51eec06"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
- "arrow-ord",
  "arrow-schema",
- "bitvec",
  "cc",
  "deepsize",
- "futures",
  "half",
  "lance-arrow",
  "lance-core",
- "log",
  "num-traits",
  "rand 0.9.2",
- "rayon",
- "tokio",
- "tracing",
 ]

 [[package]]
 name = "lance-namespace"
-version = "0.0.18"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7c0629165b5d85ff305f2de8833dcee507e899b36b098864c59f14f3b8b8e62d"
+checksum = "897dd6726816515bb70a698ce7cda44670dca5761637696d7905b45f405a8cd9"
 dependencies = [
  "arrow",
  "async-trait",
  "bytes",
- "lance",
+ "lance-core",
  "lance-namespace-reqwest-client",
- "opendal",
+ "snafu",
+]
+
+[[package]]
+name = "lance-namespace-impls"
+version = "0.39.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5e3cfcd3ba369de2719abf6fb6233f69cda639eb5cbcb328487a790e745ab988"
+dependencies = [
+ "arrow",
+ "arrow-ipc",
+ "arrow-schema",
+ "async-trait",
+ "bytes",
+ "lance",
+ "lance-core",
+ "lance-io",
+ "lance-namespace",
+ "object_store",
  "reqwest",
  "serde_json",
- "thiserror 1.0.69",
+ "snafu",
  "url",
 ]

@@ -4654,9 +4648,9 @@ dependencies = [

 [[package]]
 name = "lance-table"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "805e6c64efbb3295f74714668c9033121ffdfa6c868f067024e65ade700b8b8b"
+checksum = "c8facc13760ba034b6c38767b16adba85e44cbcbea8124dc0c63c43865c60630"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4683,6 +4677,7 @@ dependencies = [
  "rand 0.9.2",
  "rangemap",
  "roaring",
+ "semver",
  "serde",
  "serde_json",
  "snafu",
@@ -4694,9 +4689,9 @@ dependencies = [

 [[package]]
 name = "lance-testing"
-version = "0.38.2"
+version = "0.39.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8ac735b5eb153a6ac841ce0206e4c30df941610c812cc89c8ae20006f8d0b018"
+checksum = "b05052ef86188d6ae6339bdd9f2c5d77190e8ad1158f3dc8a42fa91bde9e5246"
 dependencies = [
  "arrow-array",
  "arrow-schema",
@@ -4707,8 +4702,9 @@ dependencies = [

 [[package]]
 name = "lancedb"
-version = "0.22.2-beta.2"
+version = "0.22.3-beta.5"
 dependencies = [
+ "ahash",
  "anyhow",
  "arrow",
  "arrow-array",
@@ -4717,6 +4713,7 @@ dependencies = [
  "arrow-ipc",
  "arrow-ord",
  "arrow-schema",
+ "arrow-select",
  "async-openai",
  "async-trait",
  "aws-config",
@@ -4725,13 +4722,11 @@ dependencies = [
  "aws-sdk-kms",
  "aws-sdk-s3",
  "aws-smithy-runtime",
- "bytemuck_derive",
  "bytes",
  "candle-core",
  "candle-nn",
  "candle-transformers",
  "chrono",
- "crunchy",
  "datafusion",
  "datafusion-catalog",
  "datafusion-common",
@@ -4744,12 +4739,17 @@ dependencies = [
  "http 1.3.1",
  "http-body 1.0.1",
  "lance",
+ "lance-arrow",
+ "lance-core",
  "lance-datafusion",
+ "lance-datagen",
  "lance-encoding",
+ "lance-file",
  "lance-index",
  "lance-io",
  "lance-linalg",
  "lance-namespace",
+ "lance-namespace-impls",
  "lance-table",
  "lance-testing",
  "lazy_static",
@@ -4771,6 +4771,7 @@ dependencies = [
  "serde_with",
  "snafu",
  "tempfile",
+ "test-log",
  "tokenizers",
  "tokio",
  "url",
@@ -4796,7 +4797,7 @@ dependencies = [

 [[package]]
 name = "lancedb-nodejs"
-version = "0.22.2-beta.2"
+version = "0.22.3-beta.5"
 dependencies = [
  "arrow-array",
  "arrow-ipc",
@@ -4816,7 +4817,7 @@ dependencies = [

 [[package]]
 name = "lancedb-python"
-version = "0.25.2-beta.2"
+version = "0.25.3-beta.5"
 dependencies = [
  "arrow",
  "async-trait",
@@ -5176,12 +5177,9 @@ dependencies = [

 [[package]]
 name = "mock_instant"
-version = "0.3.2"
+version = "0.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9366861eb2a2c436c20b12c8dbec5f798cea6b47ad99216be0282942e2c81ea0"
-dependencies = [
- "once_cell",
-]
+checksum = "dce6dd36094cac388f119d2e9dc82dc730ef91c32a6222170d630e5414b956e6"

 [[package]]
 name = "moka"
@@ -7909,20 +7907,14 @@ version = "0.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"

-[[package]]
-name = "strum"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
-dependencies = [
- "strum_macros 0.25.3",
-]
-
 [[package]]
 name = "strum"
 version = "0.26.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
+dependencies = [
+ "strum_macros 0.26.4",
+]

 [[package]]
 name = "strum_macros"
@@ -8244,6 +8236,28 @@ dependencies = [
  "windows-sys 0.61.2",
 ]

+[[package]]
+name = "test-log"
+version = "0.2.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b"
+dependencies = [
+ "env_logger",
+ "test-log-macros",
+ "tracing-subscriber",
+]
+
+[[package]]
+name = "test-log-macros"
+version = "0.2.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.106",
+]
+
 [[package]]
 name = "thiserror"
 version = "1.0.69"
@@ -9061,21 +9075,8 @@ dependencies = [
  "windows-implement",
  "windows-interface",
  "windows-link 0.1.3",
- "windows-result 0.3.4",
- "windows-strings 0.4.2",
-]
-
-[[package]]
-name = "windows-core"
-version = "0.62.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb"
-dependencies = [
- "windows-implement",
- "windows-interface",
- "windows-link 0.2.1",
- "windows-result 0.4.1",
- "windows-strings 0.5.1",
+ "windows-result",
+ "windows-strings",
 ]

 [[package]]
@@ -9140,8 +9141,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e"
 dependencies = [
  "windows-link 0.1.3",
- "windows-result 0.3.4",
- "windows-strings 0.4.2",
+ "windows-result",
+ "windows-strings",
 ]

 [[package]]
@@ -9153,15 +9154,6 @@ dependencies = [
  "windows-link 0.1.3",
 ]

-[[package]]
-name = "windows-result"
-version = "0.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5"
-dependencies = [
- "windows-link 0.2.1",
-]
-
 [[package]]
 name = "windows-strings"
 version = "0.4.2"
@@ -9171,15 +9163,6 @@ dependencies = [
  "windows-link 0.1.3",
 ]

-[[package]]
-name = "windows-strings"
-version = "0.5.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091"
-dependencies = [
- "windows-link 0.2.1",
-]
-
 [[package]]
 name = "windows-sys"
 version = "0.45.0"
45 Cargo.toml
@@ -15,15 +15,21 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"

 [workspace.dependencies]
-lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"] }
-lance-io = { "version" = "=0.38.2", default-features = false }
-lance-index = "=0.38.2"
-lance-linalg = "=0.38.2"
-lance-table = "=0.38.2"
-lance-testing = "=0.38.2"
-lance-datafusion = "=0.38.2"
-lance-encoding = "=0.38.2"
-lance-namespace = "0.0.18"
+lance = { "version" = "=0.39.0", default-features = false }
+lance-core = "=0.39.0"
+lance-datagen = "=0.39.0"
+lance-file = "=0.39.0"
+lance-io = { "version" = "=0.39.0", default-features = false }
+lance-index = "=0.39.0"
+lance-linalg = "=0.39.0"
+lance-namespace = "=0.39.0"
+lance-namespace-impls = { "version" = "=0.39.0", "features" = ["dir-aws", "dir-gcp", "dir-azure", "dir-oss", "rest"] }
+lance-table = "=0.39.0"
+lance-testing = "=0.39.0"
+lance-datafusion = "=0.39.0"
+lance-encoding = "=0.39.0"
+lance-arrow = "=0.39.0"
+ahash = "0.8"
 # Note that this one does not include pyarrow
 arrow = { version = "56.2", optional = false }
 arrow-array = "56.2"
@@ -31,6 +37,7 @@ arrow-data = "56.2"
 arrow-ipc = "56.2"
 arrow-ord = "56.2"
 arrow-schema = "56.2"
+arrow-select = "56.2"
 arrow-cast = "56.2"
 async-trait = "0"
 datafusion = { version = "50.1", default-features = false }
@@ -48,27 +55,11 @@ log = "0.4"
 moka = { version = "0.12", features = ["future"] }
 object_store = "0.12.0"
 pin-project = "1.0.7"
+rand = "0.9"
 snafu = "0.8"
 url = "2"
 num-traits = "0.2"
 regex = "1.10"
 lazy_static = "1"
 semver = "1.0.25"
-crunchy = "0.2.4"
-# Temporary pins to work around downstream issues
-# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
-chrono = "=0.4.41"
-# Workaround for: https://github.com/Lokathor/bytemuck/issues/306
-bytemuck_derive = ">=1.8.1, <1.9.0"
-
-# This is only needed when we reference preview releases of lance
-# [patch.crates-io]
-# # Force to use the same lance version as the rest of the project to avoid duplicate dependencies
-# lance = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
-# lance-io = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
-# lance-index = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
-# lance-linalg = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
-# lance-table = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
-# lance-testing = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
-# lance-datafusion = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
-# lance-encoding = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
+chrono = "0.4"
@@ -55,7 +55,7 @@ def extract_features(line: str) -> list:
     match = re.search(r'"features"\s*=\s*\[\s*(.*?)\s*\]', line, re.DOTALL)
     if match:
         features_str = match.group(1)
-        return [f.strip('"') for f in features_str.split(",") if len(f) > 0]
+        return [f.strip().strip('"') for f in features_str.split(",") if f.strip()]
     return []

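The effect of the `extract_features` change is easiest to see on a sample dependency line; the line below is illustrative (the "substrait" feature is hypothetical), but the regex is the one from the script:

```python
import re

line = 'lance = { "version" = "=0.39.0", "features" = ["dynamodb", "substrait"] }'
match = re.search(r'"features"\s*=\s*\[\s*(.*?)\s*\]', line, re.DOTALL)
features_str = match.group(1)  # '"dynamodb", "substrait"'

old = [f.strip('"') for f in features_str.split(",") if len(f) > 0]
new = [f.strip().strip('"') for f in features_str.split(",") if f.strip()]

print(old)  # ['dynamodb', ' "substrait'] -- the space after the comma survives
print(new)  # ['dynamodb', 'substrait']
```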
@@ -117,7 +117,7 @@ def update_cargo_toml(line_updater):
     lance_line = ""
     is_parsing_lance_line = False
     for line in lines:
-        if line.startswith("lance") and not line.startswith("lance-namespace"):
+        if line.startswith("lance"):
             # Check if this is a single-line or multi-line entry
             # Single-line entries either:
             # 1. End with } (complete inline table)
@@ -183,10 +183,8 @@ def set_preview_version(version: str):

     def line_updater(line: str) -> str:
         package_name = line.split("=", maxsplit=1)[0].strip()
-        base_version = version.split("-")[0]  # Get the base version without beta suffix
-
         # Build config in desired order: version, default-features, features, tag, git
-        config = {"version": f"={base_version}"}
+        config = {"version": f"={version}"}

         if extract_default_features(line):
             config["default-features"] = False
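Dropping `base_version` changes what gets pinned when the incoming version carries a pre-release suffix. A small illustration (the version string is hypothetical):

```python
version = "0.39.1-beta.1"

# Old behaviour: strip the pre-release suffix before pinning.
base_version = version.split("-")[0]
old_config = {"version": f"={base_version}"}

# New behaviour: pin the exact version string, pre-release included.
new_config = {"version": f"={version}"}

print(old_config)  # {'version': '=0.39.1'}
print(new_config)  # {'version': '=0.39.1-beta.1'}
```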
@@ -0,0 +1,97 @@
# VoyageAI Embeddings: Multimodal

VoyageAI embeddings can also be used to embed both text and image data. Only some of the models support image data; you can check the list
under [https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings)

Supported parameters (to be passed in `create` method) are:

| Parameter | Type | Default Value | Description |
|---|---|-------------------------|-------------------------------------------|
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |

Usage Example:

```python
import base64
import os
from io import BytesIO

import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd

os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'

db = lancedb.connect(".lancedb")
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")


def image_to_base64(image_bytes: bytes):
    buffered = BytesIO(image_bytes)
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")


class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()  # image uri as the source
    image_bytes: str = func.SourceField()  # image bytes base64 encoded as the source
    vector: Vector(func.ndims()) = func.VectorField()  # vector column
    vec_from_bytes: Vector(func.ndims()) = func.VectorField()  # Another vector column


if "images" in db.table_names():
    db.drop_table("images")
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
    "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
    "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
table.add(
    pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
```

Now we can search using text from both the default vector column and the custom vector column.

```python
# text search
actual = table.search("man's best friend", "vec_from_bytes").limit(1).to_pydantic(Images)[0]
print(actual.label)  # prints "dog"

frombytes = (
    table.search("man's best friend", vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(frombytes.label)
```

Because we're using a multi-modal embedding function, we can also search using images.

```python
# image search
from PIL import Image  # needed for Image.open below

query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))
actual = table.search(query_image, "vec_from_bytes").limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")

# image search using a custom vector column
other = (
    table.search(query_image, vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(other.label)
```
@@ -397,117 +397,6 @@ For **read-only access**, LanceDB will need a policy such as:
 }
 ```

-#### DynamoDB Commit Store for concurrent writes
-
-By default, S3 does not support concurrent writes. Having two or more processes
-writing to the same table at the same time can lead to data corruption. This is
-because S3, unlike other object stores, does not have any atomic put or copy
-operation.
-
-To enable concurrent writes, you can configure LanceDB to use a DynamoDB table
-as a commit store. This table will be used to coordinate writes between
-different processes. To enable this feature, you must modify your connection
-URI to use the `s3+ddb` scheme and add a query parameter `ddbTableName` with the
-name of the table to use.
-
-=== "Python"
-
-    === "Sync API"
-
-        ```python
-        import lancedb
-        db = lancedb.connect(
-            "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
-        )
-        ```
-    === "Async API"
-
-        ```python
-        import lancedb
-        async_db = await lancedb.connect_async(
-            "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
-        )
-        ```
-
-=== "JavaScript"
-
-    ```javascript
-    const lancedb = require("lancedb");
-
-    const db = await lancedb.connect(
-      "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
-    );
-    ```
-
-The DynamoDB table must be created with the following schema:
-
-- Hash key: `base_uri` (string)
-- Range key: `version` (number)
-
-You can create this programmatically with:
-
-=== "Python"
-
-    <!-- skip-test -->
-    ```python
-    import boto3
-
-    dynamodb = boto3.client("dynamodb")
-    table = dynamodb.create_table(
-        TableName=table_name,
-        KeySchema=[
-            {"AttributeName": "base_uri", "KeyType": "HASH"},
-            {"AttributeName": "version", "KeyType": "RANGE"},
-        ],
-        AttributeDefinitions=[
-            {"AttributeName": "base_uri", "AttributeType": "S"},
-            {"AttributeName": "version", "AttributeType": "N"},
-        ],
-        ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
-    )
-    ```
-
-=== "JavaScript"
-
-    <!-- skip-test -->
-    ```javascript
-    import {
-      CreateTableCommand,
-      DynamoDBClient,
-    } from "@aws-sdk/client-dynamodb";
-
-    const dynamodb = new DynamoDBClient({
-      region: CONFIG.awsRegion,
-      credentials: {
-        accessKeyId: CONFIG.awsAccessKeyId,
-        secretAccessKey: CONFIG.awsSecretAccessKey,
-      },
-      endpoint: CONFIG.awsEndpoint,
-    });
-    const command = new CreateTableCommand({
-      TableName: table_name,
-      AttributeDefinitions: [
-        {
-          AttributeName: "base_uri",
-          AttributeType: "S",
-        },
-        {
-          AttributeName: "version",
-          AttributeType: "N",
-        },
-      ],
-      KeySchema: [
-        { AttributeName: "base_uri", KeyType: "HASH" },
-        { AttributeName: "version", KeyType: "RANGE" },
-      ],
-      ProvisionedThroughput: {
-        ReadCapacityUnits: 1,
-        WriteCapacityUnits: 1,
-      },
-    });
-    await client.send(command);
-    ```
-
 #### S3-compatible stores


@@ -194,6 +194,37 @@ currently is also a memory intensive operation.

***

### ivfRq()

```ts
static ivfRq(options?): Index
```

Create an IvfRq index

IVF-RQ (RabitQ Quantization) compresses vectors using RabitQ quantization
and organizes them into IVF partitions.

The compression scheme is called RabitQ quantization. Each dimension is quantized into a small number of bits.
The parameters `num_bits` and `num_partitions` control this process, providing a tradeoff
between index size (and thus search speed) and index accuracy.

The partitioning process is called IVF and the `num_partitions` parameter controls how
many groups to create.

Note that training an IVF RQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.

#### Parameters

* **options?**: `Partial`<[`IvfRqOptions`](../interfaces/IvfRqOptions.md)>

#### Returns

[`Index`](Index.md)
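
For illustration, a minimal usage sketch (the database path, table name, column name `vec`, and option values are placeholders, not part of the generated reference):

```ts
import { connect, Index } from "@lancedb/lancedb";

// Open a table that has a vector column named "vec" and build an IVF-RQ index on it.
const db = await connect("data/sample-lancedb");
const tbl = await db.openTable("my_table");
await tbl.createIndex("vec", {
  config: Index.ivfRq({ numPartitions: 10, numBits: 1 }),
});
```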

***

### labelList()

250 docs/src/js/classes/PermutationBuilder.md Normal file
@@ -0,0 +1,250 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / PermutationBuilder

# Class: PermutationBuilder

A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.

This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
offering methods to configure data splits, shuffling, and filtering before executing
the permutation to create a new table.

## Methods

### execute()

```ts
execute(): Promise<Table>
```

Execute the permutation and create the destination table.

#### Returns

`Promise`<[`Table`](Table.md)>

A Promise that resolves to the new Table instance

#### Example

```ts
const permutationTable = await builder.execute();
console.log(`Created table: ${permutationTable.name}`);
```

***

### filter()

```ts
filter(filter): PermutationBuilder
```

Configure filtering for the permutation.

#### Parameters

* **filter**: `string`
    SQL filter expression

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
builder.filter("age > 18 AND status = 'active'");
```

***

### persist()

```ts
persist(connection, tableName): PermutationBuilder
```

Configure the permutation to be persisted.

#### Parameters

* **connection**: [`Connection`](Connection.md)
    The connection to persist the permutation to

* **tableName**: `string`
    The name of the table to create

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
builder.persist(connection, "permutation_table");
```

***

### shuffle()

```ts
shuffle(options): PermutationBuilder
```

Configure shuffling for the permutation.

#### Parameters

* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
    Configuration for shuffling

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
// Basic shuffle
builder.shuffle({ seed: 42 });

// Shuffle with clump size
builder.shuffle({ seed: 42, clumpSize: 10 });
```

***

### splitCalculated()

```ts
splitCalculated(options): PermutationBuilder
```

Configure calculated splits for the permutation.

#### Parameters

* **options**: [`SplitCalculatedOptions`](../interfaces/SplitCalculatedOptions.md)
    Configuration for calculated splitting

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
builder.splitCalculated({ calculation: "user_id % 3" });
```

***

### splitHash()

```ts
splitHash(options): PermutationBuilder
```

Configure hash-based splits for the permutation.

#### Parameters

* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
    Configuration for hash-based splitting

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
builder.splitHash({
  columns: ["user_id"],
  splitWeights: [70, 30],
  discardWeight: 0
});
```

***

### splitRandom()

```ts
splitRandom(options): PermutationBuilder
```

Configure random splits for the permutation.

#### Parameters

* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
    Configuration for random splitting

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
// Split by ratios
builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });

// Split by counts
builder.splitRandom({ counts: [1000, 500], seed: 42 });

// Split with fixed size
builder.splitRandom({ fixed: 100, seed: 42 });
```

***

### splitSequential()

```ts
splitSequential(options): PermutationBuilder
```

Configure sequential splits for the permutation.

#### Parameters

* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
    Configuration for sequential splitting

#### Returns

[`PermutationBuilder`](PermutationBuilder.md)

A new PermutationBuilder instance

#### Example

```ts
// Split by ratios
builder.splitSequential({ ratios: [0.8, 0.2] });

// Split by counts
builder.splitSequential({ counts: [800, 200] });

// Split with fixed size
builder.splitSequential({ fixed: 1000 });
```
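
Taken together, a minimal end-to-end sketch (database path, table name, filter, and split values are placeholders; the calls mirror the methods documented above):

```ts
import { connect, permutationBuilder } from "@lancedb/lancedb";

const db = await connect("data/sample-lancedb");
const source = await db.openTable("my_table");

// Filter, split, shuffle, and persist the permutation as a new table.
const permutation = await permutationBuilder(source)
  .filter("value > 30")
  .splitRandom({ ratios: [0.8, 0.2], seed: 42, splitNames: ["train", "test"] })
  .shuffle({ seed: 123 })
  .persist(db, "my_permutation")
  .execute();

console.log(await permutation.countRows());
```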

@@ -80,7 +80,7 @@ AnalyzeExec verbose=true, metrics=[]

### execute()

```ts
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
```

Execute the query and return the results as an

@@ -91,7 +91,7 @@ Execute the query and return the results as an

#### Returns

`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>

#### See

@@ -343,6 +343,29 @@ This is useful for pagination.

***

### outputSchema()

```ts
outputSchema(): Promise<Schema<any>>
```

Returns the schema of the output that will be returned by this query.

This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.

#### Returns

`Promise`<`Schema`<`any`>>

An Arrow Schema describing the output columns.

#### Inherited from

`StandardQueryBase.outputSchema`
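
A brief usage sketch (assuming an open `table`; the column names are placeholders):

```ts
// Inspect the columns a query will produce without executing it.
const schema = await table.query().select(["id", "text"]).outputSchema();
console.log(schema.fields.map((f) => f.name));
```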

***

### select()

@@ -81,7 +81,7 @@ AnalyzeExec verbose=true, metrics=[]

### execute()

```ts
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
```

Execute the query and return the results as an

@@ -92,7 +92,7 @@ Execute the query and return the results as an

#### Returns

`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>

#### See

@@ -140,6 +140,25 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();

***

### outputSchema()

```ts
outputSchema(): Promise<Schema<any>>
```

Returns the schema of the output that will be returned by this query.

This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.

#### Returns

`Promise`<`Schema`<`any`>>

An Arrow Schema describing the output columns.

***

### select()

@@ -1,43 +0,0 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / RecordBatchIterator

# Class: RecordBatchIterator

## Implements

- `AsyncIterator`<`RecordBatch`>

## Constructors

### new RecordBatchIterator()

```ts
new RecordBatchIterator(promise?): RecordBatchIterator
```

#### Parameters

* **promise?**: `Promise`<`RecordBatchIterator`>

#### Returns

[`RecordBatchIterator`](RecordBatchIterator.md)

## Methods

### next()

```ts
next(): Promise<IteratorResult<RecordBatch<any>, any>>
```

#### Returns

`Promise`<`IteratorResult`<`RecordBatch`<`any`>, `any`>>

#### Implementation of

`AsyncIterator.next`

@@ -76,7 +76,7 @@ AnalyzeExec verbose=true, metrics=[]

### execute()

```ts
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
```

Execute the query and return the results as an

@@ -87,7 +87,7 @@ Execute the query and return the results as an

#### Returns

`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>

#### See

@@ -143,6 +143,29 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();

***

### outputSchema()

```ts
outputSchema(): Promise<Schema<any>>
```

Returns the schema of the output that will be returned by this query.

This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.

#### Returns

`Promise`<`Schema`<`any`>>

An Arrow Schema describing the output columns.

#### Inherited from

[`QueryBase`](QueryBase.md).[`outputSchema`](QueryBase.md#outputschema)

***

### select()

@@ -221,7 +221,7 @@ also increase the latency of your query. The default value is 1.5*limit.

### execute()

```ts
protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
```

Execute the query and return the results as an

@@ -232,7 +232,7 @@ Execute the query and return the results as an

#### Returns

`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>

#### See

@@ -498,6 +498,29 @@ This is useful for pagination.

***

### outputSchema()

```ts
outputSchema(): Promise<Schema<any>>
```

Returns the schema of the output that will be returned by this query.

This can be used to inspect the types and names of the columns that will be
returned by the query before executing it.

#### Returns

`Promise`<`Schema`<`any`>>

An Arrow Schema describing the output columns.

#### Inherited from

`StandardQueryBase.outputSchema`

***

### postfilter()

19 docs/src/js/functions/RecordBatchIterator.md Normal file
@@ -0,0 +1,19 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / RecordBatchIterator

# Function: RecordBatchIterator()

```ts
function RecordBatchIterator(promisedInner): AsyncGenerator<RecordBatch<any>, void, unknown>
```

## Parameters

* **promisedInner**: `Promise`<`RecordBatchIterator`>

## Returns

`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
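
For context, consuming such an `AsyncGenerator` is plain `for await` iteration. A generic sketch (the `batches` parameter stands in for whatever produced the generator):

```ts
import type { RecordBatch } from "apache-arrow";

// Sum row counts from an AsyncGenerator of Arrow RecordBatch.
async function countRows(
  batches: AsyncGenerator<RecordBatch>,
): Promise<number> {
  let total = 0;
  for await (const batch of batches) {
    total += batch.numRows;
  }
  return total;
}
```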

34 docs/src/js/functions/permutationBuilder.md Normal file
@@ -0,0 +1,34 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / permutationBuilder

# Function: permutationBuilder()

```ts
function permutationBuilder(table): PermutationBuilder
```

Create a permutation builder for the given table.

## Parameters

* **table**: [`Table`](../classes/Table.md)
    The source table to create a permutation from

## Returns

[`PermutationBuilder`](../classes/PermutationBuilder.md)

A PermutationBuilder instance

## Example

```ts
const builder = permutationBuilder(sourceTable)
  .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
  .shuffle({ seed: 123 });

const trainingTable = await builder.execute();
```

@@ -28,10 +28,10 @@

- [MultiMatchQuery](classes/MultiMatchQuery.md)
- [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md)
- [OAuthHeaderProvider](classes/OAuthHeaderProvider.md)
- [PermutationBuilder](classes/PermutationBuilder.md)
- [PhraseQuery](classes/PhraseQuery.md)
- [Query](classes/Query.md)
- [QueryBase](classes/QueryBase.md)
- [Session](classes/Session.md)
- [StaticHeaderProvider](classes/StaticHeaderProvider.md)
- [Table](classes/Table.md)

@@ -68,6 +68,7 @@

- [IndexStatistics](interfaces/IndexStatistics.md)
- [IvfFlatOptions](interfaces/IvfFlatOptions.md)
- [IvfPqOptions](interfaces/IvfPqOptions.md)
- [IvfRqOptions](interfaces/IvfRqOptions.md)
- [MergeResult](interfaces/MergeResult.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)

@@ -75,6 +76,11 @@

- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
- [RemovalStats](interfaces/RemovalStats.md)
- [RetryConfig](interfaces/RetryConfig.md)
- [ShuffleOptions](interfaces/ShuffleOptions.md)
- [SplitCalculatedOptions](interfaces/SplitCalculatedOptions.md)
- [SplitHashOptions](interfaces/SplitHashOptions.md)
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
- [TableNamesOptions](interfaces/TableNamesOptions.md)
- [TableStatistics](interfaces/TableStatistics.md)
- [TimeoutConfig](interfaces/TimeoutConfig.md)

@@ -99,6 +105,8 @@

## Functions

- [RecordBatchIterator](functions/RecordBatchIterator.md)
- [connect](functions/connect.md)
- [makeArrowTable](functions/makeArrowTable.md)
- [packBits](functions/packBits.md)
- [permutationBuilder](functions/permutationBuilder.md)

101 docs/src/js/interfaces/IvfRqOptions.md Normal file
@@ -0,0 +1,101 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / IvfRqOptions

# Interface: IvfRqOptions

## Properties

### distanceType?

```ts
optional distanceType: "l2" | "cosine" | "dot";
```

Distance type to use to build the index.

Default value is "l2".

This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and during quantization.

The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.

The following distance types are available:

"l2" - Euclidean distance.
"cosine" - Cosine distance.
"dot" - Dot product.

***

### maxIterations?

```ts
optional maxIterations: number;
```

Max iterations to train IVF kmeans.

When training an IVF index we use kmeans to calculate the partitions. This parameter
controls how many iterations of kmeans to run.

The default value is 50.

***

### numBits?

```ts
optional numBits: number;
```

Number of bits per dimension for residual quantization.

This value controls how much each residual component is compressed. The more
bits, the more accurate the index will be, but the slower the search. Typical values
are small integers; the default is 1 bit per dimension.

***

### numPartitions?

```ts
optional numPartitions: number;
```

The number of IVF partitions to create.

This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.

If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.

***

### sampleRate?

```ts
optional sampleRate: number;
```

The number of vectors, per partition, to sample when training IVF kmeans.

When an IVF index is trained, we need to calculate partitions. These are groups
of vectors that are similar to each other. To do this we use an algorithm called kmeans.

Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
random sample of the data. This parameter controls the size of the sample. The total
number of vectors used to train the index is `sample_rate * num_partitions`.

Increasing this value might improve the quality of the index but in most cases the
default should be sufficient.

The default value is 256.
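
As a reference sketch, a fully specified options object might look like the following (the values are illustrative, not recommendations):

```ts
import { Index } from "@lancedb/lancedb";

// All IvfRqOptions fields spelled out; every property is optional.
const ivfRqConfig = Index.ivfRq({
  distanceType: "cosine",
  numPartitions: 256,
  numBits: 1,
  maxIterations: 50,
  sampleRate: 256,
});
```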

23 docs/src/js/interfaces/ShuffleOptions.md Normal file
@@ -0,0 +1,23 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / ShuffleOptions

# Interface: ShuffleOptions

## Properties

### clumpSize?

```ts
optional clumpSize: number;
```

***

### seed?

```ts
optional seed: number;
```

23 docs/src/js/interfaces/SplitCalculatedOptions.md Normal file
@@ -0,0 +1,23 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / SplitCalculatedOptions

# Interface: SplitCalculatedOptions

## Properties

### calculation

```ts
calculation: string;
```

***

### splitNames?

```ts
optional splitNames: string[];
```

39 docs/src/js/interfaces/SplitHashOptions.md Normal file
@@ -0,0 +1,39 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / SplitHashOptions

# Interface: SplitHashOptions

## Properties

### columns

```ts
columns: string[];
```

***

### discardWeight?

```ts
optional discardWeight: number;
```

***

### splitNames?

```ts
optional splitNames: string[];
```

***

### splitWeights

```ts
splitWeights: number[];
```

47 docs/src/js/interfaces/SplitRandomOptions.md Normal file
@@ -0,0 +1,47 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / SplitRandomOptions

# Interface: SplitRandomOptions

## Properties

### counts?

```ts
optional counts: number[];
```

***

### fixed?

```ts
optional fixed: number;
```

***

### ratios?

```ts
optional ratios: number[];
```

***

### seed?

```ts
optional seed: number;
```

***

### splitNames?

```ts
optional splitNames: string[];
```

39 docs/src/js/interfaces/SplitSequentialOptions.md Normal file
@@ -0,0 +1,39 @@

[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / SplitSequentialOptions

# Interface: SplitSequentialOptions

## Properties

### counts?

```ts
optional counts: number[];
```

***

### fixed?

```ts
optional fixed: number;
```

***

### ratios?

```ts
optional ratios: number[];
```

***

### splitNames?

```ts
optional splitNames: string[];
```

@@ -51,8 +51,11 @@ pub enum Error {

    DatasetAlreadyExists { uri: String, location: Location },
    #[snafu(display("Table '{name}' already exists"))]
    TableAlreadyExists { name: String },
    #[snafu(display("Table '{name}' was not found: {source}"))]
    TableNotFound {
        name: String,
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    #[snafu(display("Invalid table name '{name}': {reason}"))]
    InvalidTableName { name: String, reason: String },
    #[snafu(display("Embedding function '{name}' was not found: {reason}, {location}"))]

@@ -191,7 +194,7 @@ impl From<lancedb::Error> for Error {

                message,
                location: std::panic::Location::caller().to_snafu_location(),
            },
            lancedb::Error::TableNotFound { name, source } => Self::TableNotFound { name, source },
            lancedb::Error::TableAlreadyExists { name } => Self::TableAlreadyExists { name },
            lancedb::Error::EmbeddingFunctionNotFound { name, reason } => {
                Self::EmbeddingFunctionNotFound {

@@ -8,7 +8,7 @@

  <parent>
    <groupId>com.lancedb</groupId>
    <artifactId>lancedb-parent</artifactId>
    <version>0.22.3-beta.5</version>
    <relativePath>../pom.xml</relativePath>
  </parent>

@@ -8,7 +8,7 @@

  <parent>
    <groupId>com.lancedb</groupId>
    <artifactId>lancedb-parent</artifactId>
    <version>0.22.3-beta.5</version>
    <relativePath>../pom.xml</relativePath>
  </parent>

@@ -6,7 +6,7 @@

  <groupId>com.lancedb</groupId>
  <artifactId>lancedb-parent</artifactId>
  <version>0.22.3-beta.5</version>
  <packaging>pom</packaging>
  <name>${project.artifactId}</name>
  <description>LanceDB Java SDK Parent POM</description>

13 nodejs/AGENTS.md Normal file
@@ -0,0 +1,13 @@

These are the typescript bindings of LanceDB.
The core Rust library is in the `../rust/lancedb` directory, the rust binding
code is in the `src/` directory and the typescript bindings are in
the `lancedb/` directory.

Whenever you change the Rust code, you will need to recompile: `npm run build`.

Common commands:
* Build: `npm run build`
* Lint: `npm run lint`
* Fix lints: `npm run lint-fix`
* Test: `npm test`
* Run single test file: `npm test __test__/arrow.test.ts`

@@ -1,13 +0,0 @@

These are the typescript bindings of LanceDB.
The core Rust library is in the `../rust/lancedb` directory, the rust binding
code is in the `src/` directory and the typescript bindings are in
the `lancedb/` directory.

Whenever you change the Rust code, you will need to recompile: `npm run build`.

Common commands:
* Build: `npm run build`
* Lint: `npm run lint`
* Fix lints: `npm run lint-fix`
* Test: `npm test`
* Run single test file: `npm test __test__/arrow.test.ts`

1 nodejs/CLAUDE.md Symbolic link
@@ -0,0 +1 @@

AGENTS.md

@@ -1,7 +1,7 @@

[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.22.3-beta.5"
license.workspace = true
description.workspace = true
repository.workspace = true

371 nodejs/__test__/permutation.test.ts Normal file
@@ -0,0 +1,371 @@

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

import * as tmp from "tmp";
import { Table, connect, permutationBuilder } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";

describe("PermutationBuilder", () => {
  let tmpDir: tmp.DirResult;
  let table: Table;

  beforeEach(async () => {
    tmpDir = tmp.dirSync({ unsafeCleanup: true });
    const db = await connect(tmpDir.name);

    // Create test data
    const data = makeArrowTable(
      [
        { id: 1, value: 10 },
        { id: 2, value: 20 },
        { id: 3, value: 30 },
        { id: 4, value: 40 },
        { id: 5, value: 50 },
        { id: 6, value: 60 },
        { id: 7, value: 70 },
        { id: 8, value: 80 },
        { id: 9, value: 90 },
        { id: 10, value: 100 },
      ],
      { vectorColumns: {} },
    );

    table = await db.createTable("test_table", data);
  });

  afterEach(() => {
    tmpDir.removeCallback();
  });

  test("should create permutation builder", () => {
    const builder = permutationBuilder(table);
    expect(builder).toBeDefined();
  });

  test("should execute basic permutation", async () => {
    const builder = permutationBuilder(table);
    const permutationTable = await builder.execute();

    expect(permutationTable).toBeDefined();

    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);
  });

  test("should create permutation with random splits", async () => {
    const builder = permutationBuilder(table).splitRandom({
      ratios: [1.0],
      seed: 42,
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);
  });

  test("should create permutation with percentage splits", async () => {
    const builder = permutationBuilder(table).splitRandom({
      ratios: [0.3, 0.7],
      seed: 42,
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Check split distribution
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(10);
  });

  test("should create permutation with count splits", async () => {
    const builder = permutationBuilder(table).splitRandom({
      counts: [3, 7],
      seed: 42,
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Check split distribution
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBe(3);
    expect(split1Count).toBe(7);
  });

  test("should create permutation with hash splits", async () => {
    const builder = permutationBuilder(table).splitHash({
      columns: ["id"],
      splitWeights: [50, 50],
      discardWeight: 0,
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Check that splits exist
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(10);
  });

  test("should create permutation with sequential splits", async () => {
    const builder = permutationBuilder(table).splitSequential({
      ratios: [0.5, 0.5],
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Check split distribution - sequential should give exactly 5 and 5
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBe(5);
    expect(split1Count).toBe(5);
  });

  test("should create permutation with calculated splits", async () => {
    const builder = permutationBuilder(table).splitCalculated({
      calculation: "id % 2",
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Check split distribution
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(10);
  });

  test("should create permutation with shuffle", async () => {
    const builder = permutationBuilder(table).shuffle({
      seed: 42,
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);
  });

  test("should create permutation with shuffle and clump size", async () => {
    const builder = permutationBuilder(table).shuffle({
      seed: 42,
      clumpSize: 2,
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);
  });

  test("should create permutation with filter", async () => {
    const builder = permutationBuilder(table).filter("value > 50");

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100
  });

  test("should chain multiple operations", async () => {
    const builder = permutationBuilder(table)
      .filter("value <= 80")
      .splitRandom({ ratios: [0.5, 0.5], seed: 42 })
      .shuffle({ seed: 123 });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80

    // Check split distribution
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(8);
  });

  test("should throw error for invalid split arguments", () => {
    const builder = permutationBuilder(table);

    // Test no arguments provided
    expect(() => builder.splitRandom({})).toThrow(
      "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
    );

    // Test multiple arguments provided
    expect(() =>
      builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }),
    ).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided");
  });

  test("should throw error when builder is consumed", async () => {
    const builder = permutationBuilder(table);

    // Execute once
    await builder.execute();

    // Should throw error on second execution
    await expect(builder.execute()).rejects.toThrow("Builder already consumed");
  });

  test("should accept custom split names with random splits", async () => {
    const builder = permutationBuilder(table).splitRandom({
      ratios: [0.3, 0.7],
      seed: 42,
      splitNames: ["train", "test"],
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Split names are provided but split_id is still numeric (0, 1, etc.)
    // The names are metadata that can be used by higher-level APIs
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(10);
  });

  test("should accept custom split names with hash splits", async () => {
    const builder = permutationBuilder(table).splitHash({
      columns: ["id"],
      splitWeights: [50, 50],
      discardWeight: 0,
      splitNames: ["set_a", "set_b"],
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Split names are provided but split_id is still numeric
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(10);
  });

  test("should accept custom split names with sequential splits", async () => {
    const builder = permutationBuilder(table).splitSequential({
      ratios: [0.5, 0.5],
      splitNames: ["first", "second"],
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Split names are provided but split_id is still numeric
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBe(5);
    expect(split1Count).toBe(5);
  });

  test("should accept custom split names with calculated splits", async () => {
    const builder = permutationBuilder(table).splitCalculated({
      calculation: "id % 2",
      splitNames: ["even", "odd"],
    });

    const permutationTable = await builder.execute();
    const rowCount = await permutationTable.countRows();
    expect(rowCount).toBe(10);

    // Split names are provided but split_id is still numeric
    const split0Count = await permutationTable.countRows("split_id = 0");
    const split1Count = await permutationTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(10);
  });

  test("should persist permutation to a new table", async () => {
    const db = await connect(tmpDir.name);
    const builder = permutationBuilder(table)
      .splitRandom({
        ratios: [0.7, 0.3],
        seed: 42,
        splitNames: ["train", "validation"],
      })
      .persist(db, "my_permutation");

    // Execute the builder which will persist the table
    const permutationTable = await builder.execute();

    // Verify the persisted table exists and can be opened
    const persistedTable = await db.openTable("my_permutation");
    expect(persistedTable).toBeDefined();

    // Verify the persisted table has the correct number of rows
    const rowCount = await persistedTable.countRows();
    expect(rowCount).toBe(10);

    // Verify splits exist (numeric split_id values)
    const split0Count = await persistedTable.countRows("split_id = 0");
    const split1Count = await persistedTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(10);

    // Verify the table returned by execute is the same as the persisted one
    const executedRowCount = await permutationTable.countRows();
    expect(executedRowCount).toBe(10);
  });

  test("should persist permutation with multiple operations", async () => {
    const db = await connect(tmpDir.name);
    const builder = permutationBuilder(table)
      .filter("value > 30")
      .splitRandom({ ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] })
      .shuffle({ seed: 456 })
      .persist(db, "filtered_permutation");

    // Execute the builder
    const permutationTable = await builder.execute();

    // Verify the persisted table
    const persistedTable = await db.openTable("filtered_permutation");
    const rowCount = await persistedTable.countRows();
    expect(rowCount).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100

    // Verify splits exist (numeric split_id values)
    const split0Count = await persistedTable.countRows("split_id = 0");
    const split1Count = await persistedTable.countRows("split_id = 1");

    expect(split0Count).toBeGreaterThan(0);
    expect(split1Count).toBeGreaterThan(0);
    expect(split0Count + split1Count).toBe(7);

    // Verify the executed table matches
    const executedRowCount = await permutationTable.countRows();
    expect(executedRowCount).toBe(7);
  });
});

111 nodejs/__test__/query.test.ts Normal file
@@ -0,0 +1,111 @@

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

import * as tmp from "tmp";

import { type Table, connect } from "../lancedb";
import {
  Field,
  FixedSizeList,
  Float32,
  Int64,
  Schema,
  Utf8,
  makeArrowTable,
} from "../lancedb/arrow";
import { Index } from "../lancedb/indices";

describe("Query outputSchema", () => {
  let tmpDir: tmp.DirResult;
  let table: Table;

  beforeEach(async () => {
    tmpDir = tmp.dirSync({ unsafeCleanup: true });
    const db = await connect(tmpDir.name);

    // Create table with explicit schema to ensure proper types
    const schema = new Schema([
      new Field("a", new Int64(), true),
      new Field("text", new Utf8(), true),
      new Field(
        "vec",
        new FixedSizeList(2, new Field("item", new Float32())),
        true,
      ),
    ]);

    const data = makeArrowTable(
      [
        { a: 1n, text: "foo", vec: [1, 2] },
        { a: 2n, text: "bar", vec: [3, 4] },
        { a: 3n, text: "baz", vec: [5, 6] },
      ],
      { schema },
    );
    table = await db.createTable("test", data);
  });

  afterEach(() => {
    tmpDir.removeCallback();
  });

  it("should return schema for plain query", async () => {
    const schema = await table.query().outputSchema();

    expect(schema.fields.length).toBe(3);
    expect(schema.fields.map((f) => f.name)).toEqual(["a", "text", "vec"]);
    expect(schema.fields[0].type.toString()).toBe("Int64");
    expect(schema.fields[1].type.toString()).toBe("Utf8");
  });

  it("should return schema with dynamic projection", async () => {
    const schema = await table.query().select({ bl: "a * 2" }).outputSchema();

    expect(schema.fields.length).toBe(1);
    expect(schema.fields[0].name).toBe("bl");
    expect(schema.fields[0].type.toString()).toBe("Int64");
  });

  it("should return schema for vector search with _distance column", async () => {
    const schema = await table
      .vectorSearch([1, 2])
      .select(["a"])
      .outputSchema();

    expect(schema.fields.length).toBe(2);
    expect(schema.fields.map((f) => f.name)).toEqual(["a", "_distance"]);
    expect(schema.fields[0].type.toString()).toBe("Int64");
    expect(schema.fields[1].type.toString()).toBe("Float32");
  });

  it("should return schema for FTS search", async () => {
    await table.createIndex("text", { config: Index.fts() });

    const schema = await table
      .search("foo", "fts")
      .select(["a"])
      .outputSchema();

    // FTS search includes _score column in addition to selected columns
    expect(schema.fields.length).toBe(2);
    expect(schema.fields.map((f) => f.name)).toContain("a");
    expect(schema.fields.map((f) => f.name)).toContain("_score");
    const aField = schema.fields.find((f) => f.name === "a");
    expect(aField?.type.toString()).toBe("Int64");
  });

  it("should return schema for take query", async () => {
    const schema = await table.takeOffsets([0]).select(["text"]).outputSchema();

    expect(schema.fields.length).toBe(1);
    expect(schema.fields[0].name).toBe("text");
    expect(schema.fields[0].type.toString()).toBe("Utf8");
  });

  it("should return full schema when no select is specified", async () => {
    const schema = await table.query().outputSchema();

    // Should return all columns
    expect(schema.fields.length).toBe(3);
  });
});

@@ -861,6 +861,15 @@ describe("When creating an index", () => {

    });
  });

  it("should be able to create IVF_RQ", async () => {
    await tbl.createIndex("vec", {
      config: Index.ivfRq({
        numPartitions: 10,
        numBits: 1,
      }),
    });
  });

  it("should allow me to replace (or not) an existing index", async () => {
    await tbl.createIndex("id");
    // Default is replace=true

@@ -43,6 +43,11 @@ export {

  DeleteResult,
  DropColumnsResult,
  UpdateResult,
  SplitCalculatedOptions,
  SplitRandomOptions,
  SplitHashOptions,
  SplitSequentialOptions,
  ShuffleOptions,
} from "./native.js";

export {

@@ -85,6 +90,7 @@ export {

  Index,
  IndexOptions,
  IvfPqOptions,
  IvfRqOptions,
  IvfFlatOptions,
  HnswPqOptions,
  HnswSqOptions,

@@ -110,6 +116,7 @@ export {

export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";

export * as embedding from "./embedding";
export { permutationBuilder, PermutationBuilder } from "./permutation";
export * as rerankers from "./rerankers";
export {
  SchemaLike,
@@ -112,6 +112,77 @@ export interface IvfPqOptions {
   sampleRate?: number;
 }

+export interface IvfRqOptions {
+  /**
+   * The number of IVF partitions to create.
+   *
+   * This value should generally scale with the number of rows in the dataset.
+   * By default the number of partitions is the square root of the number of
+   * rows.
+   *
+   * If this value is too large then the first part of the search (picking the
+   * right partition) will be slow. If this value is too small then the second
+   * part of the search (searching within a partition) will be slow.
+   */
+  numPartitions?: number;
+
+  /**
+   * Number of bits per dimension for residual quantization.
+   *
+   * This value controls how much each residual component is compressed. The more
+   * bits, the more accurate the index will be but the slower search. Typical values
+   * are small integers; the default is 1 bit per dimension.
+   */
+  numBits?: number;
+
+  /**
+   * Distance type to use to build the index.
+   *
+   * Default value is "l2".
+   *
+   * This is used when training the index to calculate the IVF partitions
+   * (vectors are grouped in partitions with similar vectors according to this
+   * distance type) and during quantization.
+   *
+   * The distance type used to train an index MUST match the distance type used
+   * to search the index. Failure to do so will yield inaccurate results.
+   *
+   * The following distance types are available:
+   *
+   * "l2" - Euclidean distance.
+   * "cosine" - Cosine distance.
+   * "dot" - Dot product.
+   */
+  distanceType?: "l2" | "cosine" | "dot";
+
+  /**
+   * Max iterations to train IVF kmeans.
+   *
+   * When training an IVF index we use kmeans to calculate the partitions. This parameter
+   * controls how many iterations of kmeans to run.
+   *
+   * The default value is 50.
+   */
+  maxIterations?: number;
+
+  /**
+   * The number of vectors, per partition, to sample when training IVF kmeans.
+   *
+   * When an IVF index is trained, we need to calculate partitions. These are groups
+   * of vectors that are similar to each other. To do this we use an algorithm called kmeans.
+   *
+   * Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
+   * random sample of the data. This parameter controls the size of the sample. The total
+   * number of vectors used to train the index is `sample_rate * num_partitions`.
+   *
+   * Increasing this value might improve the quality of the index but in most cases the
+   * default should be sufficient.
+   *
+   * The default value is 256.
+   */
+  sampleRate?: number;
+}
+
 /**
  * Options to create an `HNSW_PQ` index
  */
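To make the new options type concrete, a minimal sketch of an options literal. The specific values (a ~1M-row table, cosine distance) are illustrative assumptions, not taken from this change.

```ts
import { IvfRqOptions } from "@lancedb/lancedb";

// Roughly sqrt(1_000_000) partitions, 1 bit per dimension, cosine distance.
// maxIterations and sampleRate are left at their documented defaults.
const rqOptions: IvfRqOptions = {
  numPartitions: 1000,
  numBits: 1,
  distanceType: "cosine",
  maxIterations: 50,
  sampleRate: 256,
};
```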
@@ -523,6 +594,35 @@ export class Index {
         options?.distanceType,
         options?.numPartitions,
         options?.numSubVectors,
+        options?.numBits,
+        options?.maxIterations,
+        options?.sampleRate,
+      ),
+    );
+  }
+
+  /**
+   * Create an IvfRq index
+   *
+   * IVF-RQ (RabitQ Quantization) compresses vectors using RabitQ quantization
+   * and organizes them into IVF partitions.
+   *
+   * The compression scheme is called RabitQ quantization. Each dimension is quantized into a small number of bits.
+   * The parameters `num_bits` and `num_partitions` control this process, providing a tradeoff
+   * between index size (and thus search speed) and index accuracy.
+   *
+   * The partitioning process is called IVF and the `num_partitions` parameter controls how
+   * many groups to create.
+   *
+   * Note that training an IVF RQ index on a large dataset is a slow operation and
+   * currently is also a memory intensive operation.
+   */
+  static ivfRq(options?: Partial<IvfRqOptions>) {
+    return new Index(
+      LanceDbIndex.ivfRq(
+        options?.distanceType,
+        options?.numPartitions,
+        options?.numBits,
         options?.maxIterations,
         options?.sampleRate,
       ),
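Putting the new options interface and factory together, creating an IVF-RQ index from the TypeScript SDK would look roughly like the sketch below. The column name "vector" and the parameter values are assumptions for illustration.

```ts
import { Index } from "@lancedb/lancedb";

// Build an IVF-RQ index on the "vector" column. The distance type chosen
// here must match the distance type used later at query time.
await tbl.createIndex("vector", {
  config: Index.ivfRq({
    numPartitions: 1000,
    numBits: 1,
    distanceType: "cosine",
  }),
});
```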
nodejs/lancedb/permutation.ts (new file, 202 lines)
@@ -0,0 +1,202 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

import { Connection, LocalConnection } from "./connection.js";
import {
  PermutationBuilder as NativePermutationBuilder,
  Table as NativeTable,
  ShuffleOptions,
  SplitCalculatedOptions,
  SplitHashOptions,
  SplitRandomOptions,
  SplitSequentialOptions,
  permutationBuilder as nativePermutationBuilder,
} from "./native.js";
import { LocalTable, Table } from "./table";

/**
 * A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
 *
 * This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
 * offering methods to configure data splits, shuffling, and filtering before executing
 * the permutation to create a new table.
 */
export class PermutationBuilder {
  private inner: NativePermutationBuilder;

  /**
   * @hidden
   */
  constructor(inner: NativePermutationBuilder) {
    this.inner = inner;
  }

  /**
   * Configure the permutation to be persisted.
   *
   * @param connection - The connection to persist the permutation to
   * @param tableName - The name of the table to create
   * @returns A new PermutationBuilder instance
   * @example
   * ```ts
   * builder.persist(connection, "permutation_table");
   * ```
   */
  persist(connection: Connection, tableName: string): PermutationBuilder {
    const localConnection = connection as LocalConnection;
    const newInner = this.inner.persist(localConnection.inner, tableName);
    return new PermutationBuilder(newInner);
  }

  /**
   * Configure random splits for the permutation.
   *
   * @param options - Configuration for random splitting
   * @returns A new PermutationBuilder instance
   * @example
   * ```ts
   * // Split by ratios
   * builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
   *
   * // Split by counts
   * builder.splitRandom({ counts: [1000, 500], seed: 42 });
   *
   * // Split with fixed size
   * builder.splitRandom({ fixed: 100, seed: 42 });
   * ```
   */
  splitRandom(options: SplitRandomOptions): PermutationBuilder {
    const newInner = this.inner.splitRandom(options);
    return new PermutationBuilder(newInner);
  }

  /**
   * Configure hash-based splits for the permutation.
   *
   * @param options - Configuration for hash-based splitting
   * @returns A new PermutationBuilder instance
   * @example
   * ```ts
   * builder.splitHash({
   *   columns: ["user_id"],
   *   splitWeights: [70, 30],
   *   discardWeight: 0
   * });
   * ```
   */
  splitHash(options: SplitHashOptions): PermutationBuilder {
    const newInner = this.inner.splitHash(options);
    return new PermutationBuilder(newInner);
  }

  /**
   * Configure sequential splits for the permutation.
   *
   * @param options - Configuration for sequential splitting
   * @returns A new PermutationBuilder instance
   * @example
   * ```ts
   * // Split by ratios
   * builder.splitSequential({ ratios: [0.8, 0.2] });
   *
   * // Split by counts
   * builder.splitSequential({ counts: [800, 200] });
   *
   * // Split with fixed size
   * builder.splitSequential({ fixed: 1000 });
   * ```
   */
  splitSequential(options: SplitSequentialOptions): PermutationBuilder {
    const newInner = this.inner.splitSequential(options);
    return new PermutationBuilder(newInner);
  }

  /**
   * Configure calculated splits for the permutation.
   *
   * @param options - Configuration for calculated splitting
   * @returns A new PermutationBuilder instance
   * @example
   * ```ts
   * builder.splitCalculated("user_id % 3");
   * ```
   */
  splitCalculated(options: SplitCalculatedOptions): PermutationBuilder {
    const newInner = this.inner.splitCalculated(options);
    return new PermutationBuilder(newInner);
  }

  /**
   * Configure shuffling for the permutation.
   *
   * @param options - Configuration for shuffling
   * @returns A new PermutationBuilder instance
   * @example
   * ```ts
   * // Basic shuffle
   * builder.shuffle({ seed: 42 });
   *
   * // Shuffle with clump size
   * builder.shuffle({ seed: 42, clumpSize: 10 });
   * ```
   */
  shuffle(options: ShuffleOptions): PermutationBuilder {
    const newInner = this.inner.shuffle(options);
    return new PermutationBuilder(newInner);
  }

  /**
   * Configure filtering for the permutation.
   *
   * @param filter - SQL filter expression
   * @returns A new PermutationBuilder instance
   * @example
   * ```ts
   * builder.filter("age > 18 AND status = 'active'");
   * ```
   */
  filter(filter: string): PermutationBuilder {
    const newInner = this.inner.filter(filter);
    return new PermutationBuilder(newInner);
  }

  /**
   * Execute the permutation and create the destination table.
   *
   * @returns A Promise that resolves to the new Table instance
   * @example
   * ```ts
   * const permutationTable = await builder.execute();
   * console.log(`Created table: ${permutationTable.name}`);
   * ```
   */
  async execute(): Promise<Table> {
    const nativeTable: NativeTable = await this.inner.execute();
    return new LocalTable(nativeTable);
  }
}

/**
 * Create a permutation builder for the given table.
 *
 * @param table - The source table to create a permutation from
 * @returns A PermutationBuilder instance
 * @example
 * ```ts
 * const builder = permutationBuilder(sourceTable, "training_data")
 *   .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
 *   .shuffle({ seed: 123 });
 *
 * const trainingTable = await builder.execute();
 * ```
 */
export function permutationBuilder(table: Table): PermutationBuilder {
  // Extract the inner native table from the TypeScript wrapper
  const localTable = table as LocalTable;
  // Access inner through type assertion since it's private
  const nativeBuilder = nativePermutationBuilder(
    // biome-ignore lint/suspicious/noExplicitAny: need access to private variable
    (localTable as any).inner,
  );
  return new PermutationBuilder(nativeBuilder);
}
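A consolidated sketch of how the pieces of this new module chain together end to end. The database path, table name, filter, and split settings below are illustrative assumptions, not taken from the diff.

```ts
import { connect, permutationBuilder } from "@lancedb/lancedb";

const db = await connect("data/sample-lancedb");
const source = await db.openTable("events");

// Keep only active rows, create an 80/20 split, shuffle, and persist the
// result as a new table on the same connection.
const permutation = await permutationBuilder(source)
  .filter("status = 'active'")
  .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
  .shuffle({ seed: 123 })
  .persist(db, "events_permutation")
  .execute();

console.log(`created ${permutation.name}`);
```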
@@ -20,35 +20,25 @@ import {
 } from "./native";
 import { Reranker } from "./rerankers";

-export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
-  private promisedInner?: Promise<NativeBatchIterator>;
-  private inner?: NativeBatchIterator;
-
-  constructor(promise?: Promise<NativeBatchIterator>) {
-    // TODO: check promise reliably so we dont need to pass two arguments.
-    this.promisedInner = promise;
-  }
-
-  // biome-ignore lint/suspicious/noExplicitAny: skip
-  async next(): Promise<IteratorResult<RecordBatch<any>>> {
-    if (this.inner === undefined) {
-      this.inner = await this.promisedInner;
-    }
-    if (this.inner === undefined) {
-      throw new Error("Invalid iterator state state");
-    }
-    const n = await this.inner.next();
-    if (n == null) {
-      return Promise.resolve({ done: true, value: null });
-    }
-    const tbl = tableFromIPC(n);
-    if (tbl.batches.length != 1) {
-      throw new Error("Expected only one batch");
-    }
-    return Promise.resolve({ done: false, value: tbl.batches[0] });
-  }
-}
-/* eslint-enable */
+export async function* RecordBatchIterator(
+  promisedInner: Promise<NativeBatchIterator>,
+) {
+  const inner = await promisedInner;
+  if (inner === undefined) {
+    throw new Error("Invalid iterator state");
+  }
+  for (let buffer = await inner.next(); buffer; buffer = await inner.next()) {
+    const { batches } = tableFromIPC(buffer);
+    if (batches.length !== 1) {
+      throw new Error("Expected only one batch");
+    }
+    yield batches[0];
+  }
+}

 class RecordBatchIterable<
   NativeQueryType extends NativeQuery | NativeVectorQuery | NativeTakeQuery,
@@ -64,7 +54,7 @@ class RecordBatchIterable<

   // biome-ignore lint/suspicious/noExplicitAny: skip
   [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
-    return new RecordBatchIterator(
+    return RecordBatchIterator(
       this.inner.execute(this.options?.maxBatchLength, this.options?.timeoutMs),
     );
   }
@@ -231,10 +221,8 @@ export class QueryBase<
   * single query)
   *
   */
-  protected execute(
-    options?: Partial<QueryExecutionOptions>,
-  ): RecordBatchIterator {
-    return new RecordBatchIterator(this.nativeExecute(options));
+  protected execute(options?: Partial<QueryExecutionOptions>) {
+    return RecordBatchIterator(this.nativeExecute(options));
   }

   /**
@@ -242,8 +230,7 @@ export class QueryBase<
   */
   // biome-ignore lint/suspicious/noExplicitAny: skip
   [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
-    const promise = this.nativeExecute();
-    return new RecordBatchIterator(promise);
+    return RecordBatchIterator(this.nativeExecute());
   }

   /** Collect the results as an Arrow @see {@link ArrowTable}. */
@@ -326,6 +313,25 @@ export class QueryBase<
     return this.inner.analyzePlan();
   }
 }
+
+  /**
+   * Returns the schema of the output that will be returned by this query.
+   *
+   * This can be used to inspect the types and names of the columns that will be
+   * returned by the query before executing it.
+   *
+   * @returns An Arrow Schema describing the output columns.
+   */
+  async outputSchema(): Promise<import("./arrow").Schema> {
+    let schemaBuffer: Buffer;
+    if (this.inner instanceof Promise) {
+      schemaBuffer = await this.inner.then((inner) => inner.outputSchema());
+    } else {
+      schemaBuffer = await this.inner.outputSchema();
+    }
+    const schema = tableFromIPC(schemaBuffer).schema;
+    return schema;
+  }
 }

 export class StandardQueryBase<
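To illustrate how the refactored iterator and the new `outputSchema` method are used from application code; the table, filter, and limit below are illustrative assumptions.

```ts
const query = tbl.query().where("id > 10").limit(100);

// Inspect the output columns without executing the query.
const schema = await query.outputSchema();
console.log(schema.fields.map((f) => f.name));

// The query object remains an async iterable of Arrow RecordBatches,
// now backed by the async-generator RecordBatchIterator above.
for await (const batch of query) {
  console.log(batch.numRows);
}
```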
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": [
     "win32"
   ],
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
nodejs/package-lock.json (generated)
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.22.2-beta.2",
+      "version": "0.22.3-beta.5",
       "cpu": [
         "x64",
         "arm64"
@@ -11,7 +11,7 @@
     "ann"
   ],
   "private": false,
-  "version": "0.22.2-beta.2",
+  "version": "0.22.3-beta.5",
   "main": "dist/index.js",
   "exports": {
     ".": "./dist/index.js",
@@ -4,7 +4,7 @@
 use std::collections::HashMap;
 use std::sync::Arc;

-use lancedb::database::CreateTableMode;
+use lancedb::database::{CreateTableMode, Database};
 use napi::bindgen_prelude::*;
 use napi_derive::*;

@@ -41,6 +41,10 @@ impl Connection {
             _ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
         }
     }
+
+    pub fn database(&self) -> napi::Result<Arc<dyn Database>> {
+        Ok(self.get_inner()?.database().clone())
+    }
 }

 #[napi]
@@ -6,6 +6,7 @@ use std::sync::Mutex;
 use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
 use lancedb::index::vector::{
     IvfFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder,
+    IvfRqIndexBuilder,
 };
 use lancedb::index::Index as LanceDbIndex;
 use napi_derive::napi;
@@ -65,6 +66,36 @@ impl Index {
         })
     }

+    #[napi(factory)]
+    pub fn ivf_rq(
+        distance_type: Option<String>,
+        num_partitions: Option<u32>,
+        num_bits: Option<u32>,
+        max_iterations: Option<u32>,
+        sample_rate: Option<u32>,
+    ) -> napi::Result<Self> {
+        let mut ivf_rq_builder = IvfRqIndexBuilder::default();
+        if let Some(distance_type) = distance_type {
+            let distance_type = parse_distance_type(distance_type)?;
+            ivf_rq_builder = ivf_rq_builder.distance_type(distance_type);
+        }
+        if let Some(num_partitions) = num_partitions {
+            ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
+        }
+        if let Some(num_bits) = num_bits {
+            ivf_rq_builder = ivf_rq_builder.num_bits(num_bits);
+        }
+        if let Some(max_iterations) = max_iterations {
+            ivf_rq_builder = ivf_rq_builder.max_iterations(max_iterations);
+        }
+        if let Some(sample_rate) = sample_rate {
+            ivf_rq_builder = ivf_rq_builder.sample_rate(sample_rate);
+        }
+        Ok(Self {
+            inner: Mutex::new(Some(LanceDbIndex::IvfRq(ivf_rq_builder))),
+        })
+    }
+
     #[napi(factory)]
     pub fn ivf_flat(
         distance_type: Option<String>,
@@ -12,6 +12,7 @@ mod header;
 mod index;
 mod iterator;
 pub mod merge;
+pub mod permutation;
 mod query;
 pub mod remote;
 mod rerankers;
nodejs/src/permutation.rs (new file, 248 lines)
@@ -0,0 +1,248 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::sync::{Arc, Mutex};

use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
    permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
    permutation::split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;

#[napi(object)]
pub struct SplitRandomOptions {
    pub ratios: Option<Vec<f64>>,
    pub counts: Option<Vec<i64>>,
    pub fixed: Option<i64>,
    pub seed: Option<i64>,
    pub split_names: Option<Vec<String>>,
}

#[napi(object)]
pub struct SplitHashOptions {
    pub columns: Vec<String>,
    pub split_weights: Vec<i64>,
    pub discard_weight: Option<i64>,
    pub split_names: Option<Vec<String>>,
}

#[napi(object)]
pub struct SplitSequentialOptions {
    pub ratios: Option<Vec<f64>>,
    pub counts: Option<Vec<i64>>,
    pub fixed: Option<i64>,
    pub split_names: Option<Vec<String>>,
}

#[napi(object)]
pub struct SplitCalculatedOptions {
    pub calculation: String,
    pub split_names: Option<Vec<String>>,
}

#[napi(object)]
pub struct ShuffleOptions {
    pub seed: Option<i64>,
    pub clump_size: Option<i64>,
}

pub struct PermutationBuilderState {
    pub builder: Option<LancePermutationBuilder>,
}

#[napi]
pub struct PermutationBuilder {
    state: Arc<Mutex<PermutationBuilderState>>,
}

impl PermutationBuilder {
    pub fn new(builder: LancePermutationBuilder) -> Self {
        Self {
            state: Arc::new(Mutex::new(PermutationBuilderState {
                builder: Some(builder),
            })),
        }
    }
}

impl PermutationBuilder {
    fn modify(
        &self,
        func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
    ) -> napi::Result<Self> {
        let mut state = self.state.lock().unwrap();
        let builder = state
            .builder
            .take()
            .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
        state.builder = Some(func(builder));
        Ok(Self {
            state: self.state.clone(),
        })
    }
}

#[napi]
impl PermutationBuilder {
    #[napi]
    pub fn persist(
        &self,
        connection: &crate::connection::Connection,
        table_name: String,
    ) -> napi::Result<Self> {
        let database = connection.database()?;
        self.modify(|builder| builder.persist(database, table_name))
    }

    /// Configure random splits
    #[napi]
    pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [
            options.ratios.is_some(),
            options.counts.is_some(),
            options.fixed.is_some(),
        ]
        .iter()
        .filter(|&&x| x)
        .count();

        if split_args_count != 1 {
            return Err(napi::Error::from_reason(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = options.ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = options.counts {
            SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
        } else if let Some(fixed) = options.fixed {
            SplitSizes::Fixed(fixed as u64)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        let seed = options.seed.map(|s| s as u64);

        self.modify(|builder| {
            builder.with_split_strategy(
                SplitStrategy::Random { seed, sizes },
                options.split_names.clone(),
            )
        })
    }

    /// Configure hash-based splits
    #[napi]
    pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<Self> {
        let split_weights = options
            .split_weights
            .into_iter()
            .map(|w| w as u64)
            .collect();
        let discard_weight = options.discard_weight.unwrap_or(0) as u64;

        self.modify(move |builder| {
            builder.with_split_strategy(
                SplitStrategy::Hash {
                    columns: options.columns,
                    split_weights,
                    discard_weight,
                },
                options.split_names,
            )
        })
    }

    /// Configure sequential splits
    #[napi]
    pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [
            options.ratios.is_some(),
            options.counts.is_some(),
            options.fixed.is_some(),
        ]
        .iter()
        .filter(|&&x| x)
        .count();

        if split_args_count != 1 {
            return Err(napi::Error::from_reason(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = options.ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = options.counts {
            SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
        } else if let Some(fixed) = options.fixed {
            SplitSizes::Fixed(fixed as u64)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        self.modify(move |builder| {
            builder.with_split_strategy(SplitStrategy::Sequential { sizes }, options.split_names)
        })
    }

    /// Configure calculated splits
    #[napi]
    pub fn split_calculated(&self, options: SplitCalculatedOptions) -> napi::Result<Self> {
        self.modify(move |builder| {
            builder.with_split_strategy(
                SplitStrategy::Calculated {
                    calculation: options.calculation,
                },
                options.split_names,
            )
        })
    }

    /// Configure shuffling
    #[napi]
    pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<Self> {
        let seed = options.seed.map(|s| s as u64);
        let clump_size = options.clump_size.map(|c| c as u64);

        self.modify(|builder| {
            builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
        })
    }

    /// Configure filtering
    #[napi]
    pub fn filter(&self, filter: String) -> napi::Result<Self> {
        self.modify(|builder| builder.with_filter(filter))
    }

    /// Execute the permutation builder and create the table
    #[napi]
    pub async fn execute(&self) -> napi::Result<Table> {
        let builder = {
            let mut state = self.state.lock().unwrap();
            state
                .builder
                .take()
                .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?
        };

        let table = builder.build().await.default_error()?;
        Ok(Table::new(table))
    }
}

/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(table: &crate::table::Table) -> napi::Result<PermutationBuilder> {
    use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder;

    let inner_table = table.inner_ref()?.clone();
    let inner_builder = LancePermutationBuilder::new(inner_table);

    Ok(PermutationBuilder::new(inner_builder))
}
@@ -22,7 +22,7 @@ use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
 use crate::rerankers::Reranker;
 use crate::rerankers::RerankerCallbacks;
-use crate::util::parse_distance_type;
+use crate::util::{parse_distance_type, schema_to_buffer};

 #[napi]
 pub struct Query {
@@ -88,6 +88,12 @@ impl Query {
         self.inner = self.inner.clone().with_row_id();
     }

+    #[napi(catch_unwind)]
+    pub async fn output_schema(&self) -> napi::Result<Buffer> {
+        let schema = self.inner.output_schema().await.default_error()?;
+        schema_to_buffer(&schema)
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
@@ -273,6 +279,12 @@ impl VectorQuery {
             .rerank(Arc::new(Reranker::new(callbacks)));
     }

+    #[napi(catch_unwind)]
+    pub async fn output_schema(&self) -> napi::Result<Buffer> {
+        let schema = self.inner.output_schema().await.default_error()?;
+        schema_to_buffer(&schema)
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
@@ -346,6 +358,12 @@ impl TakeQuery {
         self.inner = self.inner.clone().with_row_id();
     }

+    #[napi(catch_unwind)]
+    pub async fn output_schema(&self) -> napi::Result<Buffer> {
+        let schema = self.inner.output_schema().await.default_error()?;
+        schema_to_buffer(&schema)
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
@@ -3,7 +3,6 @@

 use std::collections::HashMap;

-use arrow_ipc::writer::FileWriter;
 use lancedb::ipc::ipc_file_to_batches;
 use lancedb::table::{
     AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, NewColumnTransform,
@@ -16,6 +15,7 @@ use crate::error::NapiErrorExt;
 use crate::index::Index;
 use crate::merge::NativeMergeInsertBuilder;
 use crate::query::{Query, TakeQuery, VectorQuery};
+use crate::util::schema_to_buffer;

 #[napi]
 pub struct Table {
@@ -26,7 +26,7 @@ pub struct Table {
 }

 impl Table {
-    fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
+    pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
         self.inner
             .as_ref()
             .ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))
@@ -64,14 +64,7 @@ impl Table {
     #[napi(catch_unwind)]
     pub async fn schema(&self) -> napi::Result<Buffer> {
         let schema = self.inner_ref()?.schema().await.default_error()?;
-        let mut writer = FileWriter::try_new(vec![], &schema)
-            .map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
-        writer
-            .finish()
-            .map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
-        Ok(Buffer::from(writer.into_inner().map_err(|e| {
-            napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
-        })?))
+        schema_to_buffer(&schema)
     }

     #[napi(catch_unwind)]
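On the TypeScript side nothing changes for callers; a small sketch of reading the schema that the refactored helper now serializes (the table variable is an assumption for illustration).

```ts
// Table.schema() returns an Arrow Schema decoded from the IPC buffer
// produced by schema_to_buffer on the Rust side.
const schema = await tbl.schema();
for (const field of schema.fields) {
  console.log(`${field.name}: ${field.type}`);
}
```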
@@ -1,7 +1,10 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

+use arrow_ipc::writer::FileWriter;
+use arrow_schema::Schema;
 use lancedb::DistanceType;
+use napi::bindgen_prelude::Buffer;

 pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
     match distance_type.as_ref().to_lowercase().as_str() {
@@ -15,3 +18,15 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
         ))),
     }
 }
+
+/// Convert an Arrow Schema to an Arrow IPC file buffer
+pub fn schema_to_buffer(schema: &Schema) -> napi::Result<Buffer> {
+    let mut writer = FileWriter::try_new(vec![], schema)
+        .map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
+    writer
+        .finish()
+        .map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
+    Ok(Buffer::from(writer.into_inner().map_err(|e| {
+        napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
+    })?))
+}
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.25.2"
+current_version = "0.25.3"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
@@ -24,6 +24,19 @@ commit = true
 message = "Bump version: {current_version} → {new_version}"
 commit_args = ""

+# Update Cargo.lock after version bump
+pre_commit_hooks = [
+    """
+    cd python && cargo update -p lancedb-python
+    if git diff --quiet ../Cargo.lock; then
+        echo "Cargo.lock unchanged"
+    else
+        git add ../Cargo.lock
+        echo "Updated and staged Cargo.lock"
+    fi
+    """,
+]
+
 [tool.bumpversion.parts.pre_l]
 values = ["beta", "final"]
 optional_value = "final"
python/AGENTS.md (new file, 19 lines)
@@ -0,0 +1,19 @@
These are the Python bindings of LanceDB.
The core Rust library is in the `../rust/lancedb` directory, the rust binding
code is in the `src/` directory and the Python bindings are in the `lancedb/` directory.

Common commands:

* Build: `make develop`
* Format: `make format`
* Lint: `make check`
* Fix lints: `make fix`
* Test: `make test`
* Doc test: `make doctest`

Before committing changes, run lints and then formatting.

When you change the Rust code, you will need to recompile the Python bindings: `make develop`.

When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.
@@ -1,19 +0,0 @@
-These are the Python bindings of LanceDB.
-The core Rust library is in the `../rust/lancedb` directory, the rust binding
-code is in the `src/` directory and the Python bindings are in the `lancedb/` directory.
-
-Common commands:
-
-* Build: `make develop`
-* Format: `make format`
-* Lint: `make check`
-* Fix lints: `make fix`
-* Test: `make test`
-* Doc test: `make doctest`
-
-Before committing changes, run lints and then formatting.
-
-When you change the Rust code, you will need to recompile the Python bindings: `make develop`.
-
-When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
-with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.
python/CLAUDE.md (symbolic link, 1 line)
@@ -0,0 +1 @@
+AGENTS.md
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.25.2"
+version = "0.25.3"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
@@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
 from .remote import ClientConfig
 from .remote.db import RemoteDBConnection
 from .schema import vector
-from .table import AsyncTable
+from .table import AsyncTable, Table
 from ._lancedb import Session
 from .namespace import connect_namespace, LanceNamespaceDBConnection

@@ -233,6 +233,7 @@ __all__ = [
     "LanceNamespaceDBConnection",
     "RemoteDBConnection",
     "Session",
+    "Table",
     "__version__",
 ]
@@ -123,6 +123,8 @@ class Table:
     @property
     def tags(self) -> Tags: ...
     def query(self) -> Query: ...
+    def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
+    def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
    def vector_search(self) -> VectorQuery: ...

 class Tags:
@@ -165,6 +167,7 @@ class Query:
     def postfilter(self): ...
     def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
     def nearest_to_text(self, query: dict) -> FTSQuery: ...
+    async def output_schema(self) -> pa.Schema: ...
     async def execute(
         self, max_batch_length: Optional[int], timeout: Optional[timedelta]
     ) -> RecordBatchStream: ...
@@ -172,6 +175,13 @@ class Query:
     async def analyze_plan(self) -> str: ...
     def to_query_request(self) -> PyQueryRequest: ...

+class TakeQuery:
+    def select(self, columns: List[str]): ...
+    def with_row_id(self): ...
+    async def output_schema(self) -> pa.Schema: ...
+    async def execute(self) -> RecordBatchStream: ...
+    def to_query_request(self) -> PyQueryRequest: ...
+
 class FTSQuery:
     def where(self, filter: str): ...
     def select(self, columns: List[str]): ...
@@ -183,12 +193,14 @@ class FTSQuery:
     def get_query(self) -> str: ...
     def add_query_vector(self, query_vec: pa.Array) -> None: ...
     def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
+    async def output_schema(self) -> pa.Schema: ...
     async def execute(
         self, max_batch_length: Optional[int], timeout: Optional[timedelta]
     ) -> RecordBatchStream: ...
     def to_query_request(self) -> PyQueryRequest: ...

 class VectorQuery:
+    async def output_schema(self) -> pa.Schema: ...
     async def execute(self) -> RecordBatchStream: ...
     def where(self, filter: str): ...
     def select(self, columns: List[str]): ...
@@ -296,3 +308,38 @@ class AlterColumnsResult:

 class DropColumnsResult:
     version: int
+
+class AsyncPermutationBuilder:
+    def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
+    def split_random(
+        self,
+        *,
+        ratios: Optional[List[float]] = None,
+        counts: Optional[List[int]] = None,
+        fixed: Optional[int] = None,
+        seed: Optional[int] = None,
+    ) -> "AsyncPermutationBuilder": ...
+    def split_hash(
+        self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
+    ) -> "AsyncPermutationBuilder": ...
+    def split_sequential(
+        self,
+        *,
+        ratios: Optional[List[float]] = None,
+        counts: Optional[List[int]] = None,
+        fixed: Optional[int] = None,
+    ) -> "AsyncPermutationBuilder": ...
+    def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
+    def shuffle(
+        self, seed: Optional[int], clump_size: Optional[int]
+    ) -> "AsyncPermutationBuilder": ...
+    def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
+    async def execute(self) -> Table: ...
+
+def async_permutation_builder(
+    table: Table, dest_table_name: str
+) -> AsyncPermutationBuilder: ...
+def fts_query_to_json(query: Any) -> str: ...
+
+class PermutationReader:
+    def __init__(self, base_table: Table, permutation_table: Table): ...
@@ -5,6 +5,7 @@
 from __future__ import annotations

 from abc import abstractmethod
+from datetime import timedelta
 from pathlib import Path
 import sys
 from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
@@ -40,7 +41,6 @@ import deprecation
 if TYPE_CHECKING:
     import pyarrow as pa
     from .pydantic import LanceModel
-    from datetime import timedelta

 from ._lancedb import Connection as LanceDbConnection
 from .common import DATA, URI
@@ -452,7 +452,12 @@ class LanceDBConnection(DBConnection):
         read_consistency_interval: Optional[timedelta] = None,
         storage_options: Optional[Dict[str, str]] = None,
         session: Optional[Session] = None,
+        _inner: Optional[LanceDbConnection] = None,
     ):
+        if _inner is not None:
+            self._conn = _inner
+            return
+
         if not isinstance(uri, Path):
             scheme = get_uri_scheme(uri)
         is_local = isinstance(uri, Path) or scheme == "file"
@@ -461,11 +466,6 @@ class LanceDBConnection(DBConnection):
             uri = Path(uri)
             uri = uri.expanduser().absolute()
             Path(uri).mkdir(parents=True, exist_ok=True)
-        self._uri = str(uri)
-        self._entered = False
-        self.read_consistency_interval = read_consistency_interval
-        self.storage_options = storage_options
-        self.session = session

         if read_consistency_interval is not None:
             read_consistency_interval_secs = read_consistency_interval.total_seconds()
@@ -484,10 +484,32 @@ class LanceDBConnection(DBConnection):
                 session,
             )

+        # TODO: It would be nice if we didn't store self.storage_options but it is
+        # currently used by the LanceTable.to_lance method. This doesn't _really_
+        # work because some paths like LanceDBConnection.from_inner will lose the
+        # storage_options. Also, this class really shouldn't be holding any state
+        # beyond _conn.
+        self.storage_options = storage_options
         self._conn = AsyncConnection(LOOP.run(do_connect()))

+    @property
+    def read_consistency_interval(self) -> Optional[timedelta]:
+        return LOOP.run(self._conn.get_read_consistency_interval())
+
+    @property
+    def session(self) -> Optional[Session]:
+        return self._conn.session
+
+    @property
+    def uri(self) -> str:
+        return self._conn.uri
+
+    @classmethod
+    def from_inner(cls, inner: LanceDbConnection):
+        return cls(None, _inner=inner)
+
     def __repr__(self) -> str:
-        val = f"{self.__class__.__name__}(uri={self._uri!r}"
+        val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
         if self.read_consistency_interval is not None:
             val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
         val += ")"
@@ -497,6 +519,10 @@ class LanceDBConnection(DBConnection):
         conn = AsyncConnection(await lancedb_connect(self.uri))
         return await conn.table_names(start_after=start_after, limit=limit)

+    @property
+    def _inner(self) -> LanceDbConnection:
+        return self._conn._inner
+
     @override
     def list_namespaces(
         self,
@@ -856,6 +882,13 @@ class AsyncConnection(object):
     def uri(self) -> str:
         return self._inner.uri

+    async def get_read_consistency_interval(self) -> Optional[timedelta]:
+        interval_secs = await self._inner.get_read_consistency_interval()
+        if interval_secs is not None:
+            return timedelta(seconds=interval_secs)
+        else:
+            return None
+
     async def list_namespaces(
         self,
         namespace: List[str] = [],
@@ -3,9 +3,11 @@
|
|||||||
|
|
||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Union, Optional, Any
|
from logging import warning
|
||||||
|
from typing import List, Union, Optional, Any, Callable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import io
|
import io
|
||||||
|
import warnings
|
||||||
|
|
||||||
from ..util import attempt_import_or_raise
|
from ..util import attempt_import_or_raise
|
||||||
from .base import EmbeddingFunction
|
from .base import EmbeddingFunction
|
||||||
@@ -19,35 +21,52 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
|||||||
    An embedding function that uses the ColPali engine for
    multimodal multi-vector embeddings.

-    This embedding function supports ColQwen2.5 models, producing multivector outputs
-    for both text and image inputs. The output embeddings are lists of vectors, each
-    vector being 128-dimensional by default, represented as List[List[float]].
+    This embedding function supports ColPali models, producing multivector outputs
+    for both text and image inputs.

    Parameters
    ----------
    model_name : str
        The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
+        Supports models based on these engines:
+        - ColPali: "vidore/colpali-v1.3" and others
+        - ColQwen2.5: "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" and others
+        - ColQwen2: "vidore/colqwen2-v1.0" and others
+        - ColSmol: "vidore/colSmol-256M" and others
    device : str
-        The device for inference (default "cuda:0").
+        The device for inference (default "auto").
    dtype : str
        Data type for model weights (default "bfloat16").
    use_token_pooling : bool
-        Whether to use token pooling to reduce embedding size (default True).
+        DEPRECATED. Whether to use token pooling. Use `pooling_strategy` instead.
+    pooling_strategy : str, optional
+        The token pooling strategy to use, by default "hierarchical".
+        - "hierarchical": Progressively pools tokens to reduce sequence length.
+        - "lambda": A simpler pooling that uses a custom `pooling_func`.
+    pooling_func: typing.Callable, optional
+        A function to use for pooling when `pooling_strategy` is "lambda".
    pool_factor : int
        Factor to reduce sequence length if token pooling is enabled (default 2).
    quantization_config : Optional[BitsAndBytesConfig]
        Quantization configuration for the model. (default None, bitsandbytes needed)
    batch_size : int
        Batch size for processing inputs (default 2).
+    offload_folder: str, optional
+        Folder to offload model weights if using CPU offloading (default None). This is
+        useful for large models that do not fit in memory.
    """

    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
    device: str = "auto"
    dtype: str = "bfloat16"
    use_token_pooling: bool = True
+    pooling_strategy: Optional[str] = "hierarchical"
+    pooling_func: Optional[Any] = None
    pool_factor: int = 2
    quantization_config: Optional[Any] = None
    batch_size: int = 2
+    offload_folder: Optional[str] = None

    _model = None
    _processor = None
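A short usage sketch may help here; it is not part of the diff. It assumes the function is registered in LanceDB's embedding registry under the key "colpali" and that the standard `EmbeddingFunction` query API applies.

```python
# Illustrative only: the registry key "colpali" is an assumption based on the
# docstring above, not something this diff confirms.
from lancedb.embeddings import get_registry

colpali = get_registry().get("colpali").create(
    model_name="vidore/colpali-v1.3",    # any engine from the supported list
    device="auto",                        # resolved to cuda / mps / cpu at init
    pooling_strategy="hierarchical",      # or "lambda" with a custom pooling_func
    pool_factor=2,
)

# Multi-vector output: one list of vectors per input text
query_vectors = colpali.compute_query_embeddings("a photo of a red bicycle")
```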
@@ -56,15 +75,43 @@ class ColPaliEmbeddings(EmbeddingFunction):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
+        torch = attempt_import_or_raise("torch", "torch")
+
+        if not self.use_token_pooling:
+            warnings.warn(
+                "use_token_pooling is deprecated, use pooling_strategy=None instead",
+                DeprecationWarning,
+            )
+            self.pooling_strategy = None
+
+        if self.pooling_strategy == "lambda" and self.pooling_func is None:
+            raise ValueError(
+                "pooling_func must be provided when pooling_strategy is 'lambda'"
+            )
+
+        device = self.device
+        if device == "auto":
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
+
+        dtype = self.dtype
+        if device == "mps" and dtype == "bfloat16":
+            dtype = "float32"  # Avoid NaNs on MPS
+
        (
            self._model,
            self._processor,
            self._token_pooler,
        ) = self._load_model(
            self.model_name,
-            self.dtype,
-            self.device,
-            self.use_token_pooling,
+            dtype,
+            device,
+            self.pooling_strategy,
+            self.pooling_func,
            self.quantization_config,
        )

@@ -74,16 +121,26 @@ class ColPaliEmbeddings(EmbeddingFunction):
        model_name: str,
        dtype: str,
        device: str,
-        use_token_pooling: bool,
+        pooling_strategy: Optional[str],
+        pooling_func: Optional[Callable],
        quantization_config: Optional[Any],
    ):
        """
        Initialize and cache the ColPali model, processor, and token pooler.
        """
+        if device.startswith("mps"):
+            # warn some torch ops in late interaction architecture result in nans on mps
+            warning(
+                "MPS device detected. Some operations may result in NaNs. "
+                "If you encounter issues, consider using 'cpu' or 'cuda' devices."
+            )
        torch = attempt_import_or_raise("torch", "torch")
        transformers = attempt_import_or_raise("transformers", "transformers")
        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
-        from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
+        from colpali_engine.compression.token_pooling import (
+            HierarchicalTokenPooler,
+            LambdaTokenPooler,
+        )

        if quantization_config is not None:
            if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
@@ -98,21 +155,45 @@ class ColPaliEmbeddings(EmbeddingFunction):
        else:
            torch_dtype = torch.float32

-        model = colpali_engine.models.ColQwen2_5.from_pretrained(
+        model_class, processor_class = None, None
+        model_name_lower = model_name.lower()
+        if "colqwen2.5" in model_name_lower:
+            model_class = colpali_engine.models.ColQwen2_5
+            processor_class = colpali_engine.models.ColQwen2_5_Processor
+        elif "colsmol" in model_name_lower or "colidefics3" in model_name_lower:
+            model_class = colpali_engine.models.ColIdefics3
+            processor_class = colpali_engine.models.ColIdefics3Processor
+        elif "colqwen" in model_name_lower:
+            model_class = colpali_engine.models.ColQwen2
+            processor_class = colpali_engine.models.ColQwen2Processor
+        elif "colpali" in model_name_lower:
+            model_class = colpali_engine.models.ColPali
+            processor_class = colpali_engine.models.ColPaliProcessor
+
+        if model_class is None:
+            raise ValueError(f"Unsupported model: {model_name}")
+
+        model = model_class.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
-            device_map=device,
            quantization_config=quantization_config
            if quantization_config is not None
            else None,
            attn_implementation="flash_attention_2"
            if is_flash_attn_2_available()
            else None,
+            low_cpu_mem_usage=True,
        ).eval()
-        processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
-            model_name
-        )
-        token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
+        model = model.to(device)
+        model = model.to(torch_dtype)  # Force cast after moving to device
+        processor = processor_class.from_pretrained(model_name)
+
+        token_pooler = None
+        if pooling_strategy == "hierarchical":
+            token_pooler = HierarchicalTokenPooler()
+        elif pooling_strategy == "lambda":
+            token_pooler = LambdaTokenPooler(pool_func=pooling_func)

        return model, processor, token_pooler

    def ndims(self):
@@ -128,7 +209,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
            with torch.no_grad():
                query_embeddings = self._model(**batch_queries)

-            if self.use_token_pooling and self._token_pooler is not None:
+            if self.pooling_strategy and self._token_pooler is not None:
                query_embeddings = self._token_pooler.pool_embeddings(
                    query_embeddings,
                    pool_factor=self.pool_factor,
@@ -145,13 +226,20 @@ class ColPaliEmbeddings(EmbeddingFunction):
        Use token pooling if enabled.
        """
        torch = attempt_import_or_raise("torch", "torch")
-        if self.use_token_pooling and self._token_pooler is not None:
-            embeddings = self._token_pooler.pool_embeddings(
-                embeddings,
-                pool_factor=self.pool_factor,
-                padding=True,
-                padding_side=self._processor.tokenizer.padding_side,
-            )
+        if self.pooling_strategy and self._token_pooler is not None:
+            if self.pooling_strategy == "hierarchical":
+                embeddings = self._token_pooler.pool_embeddings(
+                    embeddings,
+                    pool_factor=self.pool_factor,
+                    padding=True,
+                    padding_side=self._processor.tokenizer.padding_side,
+                )
+            elif self.pooling_strategy == "lambda":
+                embeddings = self._token_pooler.pool_embeddings(
+                    embeddings,
+                    padding=True,
+                    padding_side=self._processor.tokenizer.padding_side,
+                )

        if isinstance(embeddings, torch.Tensor):
            tensors = embeddings.detach().cpu()
@@ -179,6 +267,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
            )
            with torch.no_grad():
                query_embeddings = self._model(**batch_queries)
+            query_embeddings = torch.nan_to_num(query_embeddings)
            all_embeddings.extend(self._process_embeddings(query_embeddings))
        return all_embeddings

@@ -225,6 +314,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
            )
            with torch.no_grad():
                image_embeddings = self._model(**batch_images)
+            image_embeddings = torch.nan_to_num(image_embeddings)
            all_embeddings.extend(self._process_embeddings(image_embeddings))
        return all_embeddings

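The diff does not show what a "lambda" pooling callable looks like, so here is a heavily hedged sketch. The exact signature expected by colpali_engine's `LambdaTokenPooler` is an assumption here: one document's token-embedding tensor in, a pooled tensor out.

```python
# Assumption: pool_func receives a 2-D tensor of token embeddings for one
# document (num_tokens x dim) and returns a smaller 2-D tensor.
import torch

def mean_pool_pairs(token_embeddings: torch.Tensor) -> torch.Tensor:
    # Average adjacent token pairs, roughly halving the sequence length.
    n = token_embeddings.shape[0] - (token_embeddings.shape[0] % 2)
    pairs = token_embeddings[:n].reshape(-1, 2, token_embeddings.shape[1])
    return pairs.mean(dim=1)

# colpali = ColPaliEmbeddings(pooling_strategy="lambda", pooling_func=mean_pool_pairs)
```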
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import base64
import os
-from typing import ClassVar, TYPE_CHECKING, List, Union, Any
+from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator

from pathlib import Path
from urllib.parse import urlparse
@@ -19,6 +19,23 @@ from .utils import api_key_not_found_help, IMAGES, TEXT
if TYPE_CHECKING:
    import PIL

+# Token limits for different VoyageAI models
+VOYAGE_TOTAL_TOKEN_LIMITS = {
+    "voyage-context-3": 32_000,
+    "voyage-3.5-lite": 1_000_000,
+    "voyage-3.5": 320_000,
+    "voyage-3-lite": 120_000,
+    "voyage-3": 120_000,
+    "voyage-multimodal-3": 120_000,
+    "voyage-finance-2": 120_000,
+    "voyage-multilingual-2": 120_000,
+    "voyage-law-2": 120_000,
+    "voyage-code-2": 120_000,
+}
+
+# Batch size for embedding requests (max number of items per batch)
+BATCH_SIZE = 1000
+
+
def is_valid_url(text):
    try:
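The constants above drive the request batching added later in this diff. As a reading aid, the rule they imply can be stated as a small predicate; this is an illustrative sketch, not code from the diff.

```python
def fits_in_batch(batch_len: int, batch_tokens: int, next_tokens: int,
                  model: str = "voyage-3.5") -> bool:
    # A text joins the current batch only if both the item count and the
    # model's total-token budget still have room.
    max_tokens = VOYAGE_TOTAL_TOKEN_LIMITS.get(model, 120_000)
    return batch_len < BATCH_SIZE and batch_tokens + next_tokens <= max_tokens
```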
@@ -120,6 +137,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
    name: str
        The name of the model to use. List of acceptable models:

+        * voyage-context-3
+        * voyage-3.5
+        * voyage-3.5-lite
        * voyage-3
        * voyage-3-lite
        * voyage-multimodal-3
@@ -157,25 +177,35 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
    name: str
    client: ClassVar = None
    text_embedding_models: list = [
+        "voyage-3.5",
+        "voyage-3.5-lite",
        "voyage-3",
        "voyage-3-lite",
        "voyage-finance-2",
+        "voyage-multilingual-2",
        "voyage-law-2",
        "voyage-code-2",
    ]
    multimodal_embedding_models: list = ["voyage-multimodal-3"]
+    contextual_embedding_models: list = ["voyage-context-3"]

    def _is_multimodal_model(self, model_name: str):
        return (
            model_name in self.multimodal_embedding_models or "multimodal" in model_name
        )

+    def _is_contextual_model(self, model_name: str):
+        return model_name in self.contextual_embedding_models or "context" in model_name
+
    def ndims(self):
        if self.name == "voyage-3-lite":
            return 512
        elif self.name == "voyage-code-2":
            return 1536
        elif self.name in [
+            "voyage-context-3",
+            "voyage-3.5",
+            "voyage-3.5-lite",
            "voyage-3",
            "voyage-multimodal-3",
            "voyage-finance-2",
@@ -207,6 +237,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
            result = client.multimodal_embed(
                inputs=[[query]], model=self.name, input_type="query", **kwargs
            )
+        elif self._is_contextual_model(self.name):
+            result = client.contextualized_embed(
+                inputs=[[query]], model=self.name, input_type="query", **kwargs
+            )
+            result = result.results[0]
        else:
            result = client.embed(
                texts=[query], model=self.name, input_type="query", **kwargs
@@ -231,18 +266,164 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
        List[np.array]: the list of embeddings
        """
        client = VoyageAIEmbeddingFunction._get_client()

+        # For multimodal models, check if inputs contain images
        if self._is_multimodal_model(self.name):
-            inputs = sanitize_multimodal_input(inputs)
-            result = client.multimodal_embed(
-                inputs=inputs, model=self.name, input_type="document", **kwargs
-            )
+            sanitized = sanitize_multimodal_input(inputs)
+            has_images = any(
+                inp["content"][0].get("type") != "text" for inp in sanitized
+            )
+            if has_images:
+                # Use non-batched API for images
+                result = client.multimodal_embed(
+                    inputs=sanitized, model=self.name, input_type="document", **kwargs
+                )
+                return result.embeddings
+            # Extract texts for batching
+            inputs = [inp["content"][0]["text"] for inp in sanitized]
        else:
            inputs = sanitize_text_input(inputs)
-            result = client.embed(
-                texts=inputs, model=self.name, input_type="document", **kwargs
-            )

-        return result.embeddings
+        # Use batching for all text inputs
+        return self._embed_with_batching(
+            client, inputs, input_type="document", **kwargs
+        )
+
+    def _build_batches(
+        self, client, texts: List[str]
+    ) -> Generator[List[str], None, None]:
+        """
+        Generate batches of texts based on token limits using a generator.
+
+        Parameters
+        ----------
+        client : voyageai.Client
+            The VoyageAI client instance.
+        texts : List[str]
+            List of texts to batch.
+
+        Yields
+        ------
+        List[str]: Batches of texts.
+        """
+        if not texts:
+            return
+
+        max_tokens_per_batch = VOYAGE_TOTAL_TOKEN_LIMITS.get(self.name, 120_000)
+        current_batch: List[str] = []
+        current_batch_tokens = 0
+
+        # Tokenize all texts in one API call
+        token_lists = client.tokenize(texts, model=self.name)
+        token_counts = [len(token_list) for token_list in token_lists]
+
+        for i, text in enumerate(texts):
+            n_tokens = token_counts[i]
+
+            # Check if adding this text would exceed limits
+            if current_batch and (
+                len(current_batch) >= BATCH_SIZE
+                or (current_batch_tokens + n_tokens > max_tokens_per_batch)
+            ):
+                # Yield the current batch and start a new one
+                yield current_batch
+                current_batch = []
+                current_batch_tokens = 0
+
+            current_batch.append(text)
+            current_batch_tokens += n_tokens
+
+        # Yield the last batch (always has at least one text)
+        if current_batch:
+            yield current_batch
+
+    def _get_embed_function(
+        self, client, input_type: str = "document", **kwargs
+    ) -> callable:
+        """
+        Get the appropriate embedding function based on model type.
+
+        Parameters
+        ----------
+        client : voyageai.Client
+            The VoyageAI client instance.
+        input_type : str
+            Either "query" or "document"
+        **kwargs
+            Additional arguments to pass to the embedding API
+
+        Returns
+        -------
+        callable: A function that takes a batch of texts and returns embeddings.
+        """
+        if self._is_multimodal_model(self.name):
+
+            def embed_batch(batch: List[str]) -> List[np.array]:
+                batch_inputs = sanitize_multimodal_input(batch)
+                result = client.multimodal_embed(
+                    inputs=batch_inputs,
+                    model=self.name,
+                    input_type=input_type,
+                    **kwargs,
+                )
+                return result.embeddings
+
+            return embed_batch
+
+        elif self._is_contextual_model(self.name):
+
+            def embed_batch(batch: List[str]) -> List[np.array]:
+                result = client.contextualized_embed(
+                    inputs=[batch], model=self.name, input_type=input_type, **kwargs
+                )
+                return result.results[0].embeddings
+
+            return embed_batch
+
+        else:
+
+            def embed_batch(batch: List[str]) -> List[np.array]:
+                result = client.embed(
+                    texts=batch, model=self.name, input_type=input_type, **kwargs
+                )
+                return result.embeddings
+
+            return embed_batch
+
+    def _embed_with_batching(
+        self, client, texts: List[str], input_type: str = "document", **kwargs
+    ) -> List[np.array]:
+        """
+        Embed texts with automatic batching based on token limits.
+
+        Parameters
+        ----------
+        client : voyageai.Client
+            The VoyageAI client instance.
+        texts : List[str]
+            List of texts to embed.
+        input_type : str
+            Either "query" or "document"
+        **kwargs
+            Additional arguments to pass to the embedding API
+
+        Returns
+        -------
+        List[np.array]: List of embeddings.
+        """
+        if not texts:
+            return []
+
+        # Get the appropriate embedding function for this model type
+        embed_fn = self._get_embed_function(client, input_type=input_type, **kwargs)
+
+        # Process each batch
+        all_embeddings = []
+        for batch in self._build_batches(client, texts):
+            batch_embeddings = embed_fn(batch)
+            all_embeddings.extend(batch_embeddings)
+
+        return all_embeddings
+
    @staticmethod
    def _get_client():
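As a usage illustration (not part of the diff), the batching above is transparent to callers. The registry key "voyageai" and the surrounding table API are assumed from the wider LanceDB docs rather than from this change.

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

func = get_registry().get("voyageai").create(name="voyage-3.5")

class Doc(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

db = lancedb.connect("/tmp/voyage-demo")
tbl = db.create_table("docs", schema=Doc, mode="overwrite")
# Large inserts are now split into token-bounded batches behind the scenes.
tbl.add([{"text": f"document {i}"} for i in range(5000)])
```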
@@ -605,9 +605,53 @@ class IvfPq:
    target_partition_size: Optional[int] = None


+@dataclass
+class IvfRq:
+    """Describes an IVF RQ Index
+
+    IVF-RQ (Residual Quantization) stores a compressed copy of each vector using
+    residual quantization and organizes them into IVF partitions. Parameters
+    largely mirror IVF-PQ for consistency.
+
+    Attributes
+    ----------
+    distance_type: str, default "l2"
+        Distance metric used to train the index and for quantization.
+
+        The following distance types are available:
+
+        "l2" - Euclidean distance.
+        "cosine" - Cosine distance.
+        "dot" - Dot product.
+
+    num_partitions: int, default sqrt(num_rows)
+        Number of IVF partitions to create.
+
+    num_bits: int, default 1
+        Number of bits to encode each dimension.
+
+    max_iterations: int, default 50
+        Max iterations to train kmeans when computing IVF partitions.
+
+    sample_rate: int, default 256
+        Controls the number of training vectors: sample_rate * num_partitions.
+
+    target_partition_size: int, default 8192
+        Target size of each partition.
+    """
+
+    distance_type: Literal["l2", "cosine", "dot"] = "l2"
+    num_partitions: Optional[int] = None
+    num_bits: int = 1
+    max_iterations: int = 50
+    sample_rate: int = 256
+    target_partition_size: Optional[int] = None
+
+
__all__ = [
    "BTree",
    "IvfPq",
+    "IvfRq",
    "IvfFlat",
    "HnswPq",
    "HnswSq",
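A hypothetical sketch of how the new config might be used with the async API; the exact `create_index` signature is assumed from the existing IvfPq workflow, not shown in this diff.

```python
import lancedb
from lancedb.index import IvfRq

async def build_rq_index(uri: str):
    db = await lancedb.connect_async(uri)
    tbl = await db.open_table("my_vectors")
    # Residual-quantized IVF index; parameters mirror IvfPq.
    await tbl.create_index(
        "vector",
        config=IvfRq(distance_type="cosine", num_bits=1, sample_rate=256),
    )
```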
python/python/lancedb/permutation.py (new file, 821 lines)
@@ -0,0 +1,821 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
from deprecation import deprecated
|
||||||
|
from lancedb import AsyncConnection, DBConnection
|
||||||
|
import pyarrow as pa
|
||||||
|
import json
|
||||||
|
|
||||||
|
from ._lancedb import async_permutation_builder, PermutationReader
|
||||||
|
from .table import LanceTable
|
||||||
|
from .background_loop import LOOP
|
||||||
|
from .util import batch_to_tensor
|
||||||
|
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
class PermutationBuilder:
|
||||||
|
"""
|
||||||
|
A utility for creating a "permutation table" which is a table that defines an
|
||||||
|
ordering on a base table.
|
||||||
|
|
||||||
|
The permutation table does not store the actual data. It only stores row
|
||||||
|
ids and split ids to define the ordering. The [Permutation] class can be used to
|
||||||
|
read the data from the base table in the order defined by the permutation table.
|
||||||
|
|
||||||
|
Permutations can split, shuffle, and filter the data in the base table.
|
||||||
|
|
||||||
|
A filter limits the rows that are included in the permutation.
|
||||||
|
Splits divide the data into subsets (for example, a test/train split, or K
|
||||||
|
different splits for cross-validation).
|
||||||
|
Shuffling randomizes the order of the rows in the permutation.
|
||||||
|
|
||||||
|
Splits can optionally be named. If names are provided it will enable them to
|
||||||
|
be referenced by name in the future. If names are not provided then they can only
|
||||||
|
be referenced by their ordinal index. There is no requirement to name every split.
|
||||||
|
|
||||||
|
By default, the permutation will be stored in memory and will be lost when the
|
||||||
|
program exits. To persist the permutation (for very large datasets or to share
|
||||||
|
the permutation across multiple workers) use the [persist](#persist) method to
|
||||||
|
create a permanent table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, table: LanceTable):
|
||||||
|
"""
|
||||||
|
Creates a new permutation builder for the given table.
|
||||||
|
|
||||||
|
By default, the permutation builder will create a single split that contains all
|
||||||
|
rows in the same order as the base table.
|
||||||
|
"""
|
||||||
|
self._async = async_permutation_builder(table)
|
||||||
|
|
||||||
|
def persist(
|
||||||
|
self, database: Union[DBConnection, AsyncConnection], table_name: str
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Persist the permutation to the given database.
|
||||||
|
"""
|
||||||
|
self._async.persist(database, table_name)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def split_random(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ratios: Optional[list[float]] = None,
|
||||||
|
counts: Optional[list[int]] = None,
|
||||||
|
fixed: Optional[int] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
split_names: Optional[list[str]] = None,
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Configure random splits for the permutation.
|
||||||
|
|
||||||
|
One of ratios, counts, or fixed must be provided.
|
||||||
|
|
||||||
|
If ratios are provided, they will be used to determine the relative size of each
|
||||||
|
split. For example, if ratios are [0.3, 0.7] then the first split will contain
|
||||||
|
30% of the rows and the second split will contain 70% of the rows.
|
||||||
|
|
||||||
|
If counts are provided, they will be used to determine the absolute number of
|
||||||
|
rows in each split. For example, if counts are [100, 200] then the first split
|
||||||
|
will contain 100 rows and the second split will contain 200 rows.
|
||||||
|
|
||||||
|
If fixed is provided, it will be used to determine the number of splits.
|
||||||
|
For example, if fixed is 3 then the permutation will be split evenly into 3
|
||||||
|
splits.
|
||||||
|
|
||||||
|
Rows will be randomly assigned to splits. The optional seed can be provided to
|
||||||
|
make the assignment deterministic.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_random(
|
||||||
|
ratios=ratios,
|
||||||
|
counts=counts,
|
||||||
|
fixed=fixed,
|
||||||
|
seed=seed,
|
||||||
|
split_names=split_names,
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def split_hash(
|
||||||
|
self,
|
||||||
|
columns: list[str],
|
||||||
|
split_weights: list[int],
|
||||||
|
*,
|
||||||
|
discard_weight: Optional[int] = None,
|
||||||
|
split_names: Optional[list[str]] = None,
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Configure hash-based splits for the permutation.
|
||||||
|
|
||||||
|
First, a hash will be calculated over the specified columns. The split weights
|
||||||
|
are then used to determine how many rows to assign to each split. For example,
|
||||||
|
if split weights are [1, 2] then the first split will contain 1/3 of the rows
|
||||||
|
and the second split will contain 2/3 of the rows.
|
||||||
|
|
||||||
|
The optional discard weight can be provided to determine what percentage of rows
|
||||||
|
should be discarded. For example, if split weights are [1, 2] and discard
|
||||||
|
weight is 1 then 25% of the rows will be discarded.
|
||||||
|
|
||||||
|
Hash-based splits are useful if you want the split to be more or less random but
|
||||||
|
you don't want the split assignments to change if rows are added or removed
|
||||||
|
from the table.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_hash(
|
||||||
|
columns,
|
||||||
|
split_weights,
|
||||||
|
discard_weight=discard_weight,
|
||||||
|
split_names=split_names,
|
||||||
|
)
|
||||||
|
return self
|
||||||
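A quick arithmetic sketch of the weight rule described in the docstring above (illustrative, not from the diff):

```python
# split_weights=[1, 2] with discard_weight=1 gives a total weight of 4:
# split 0 gets 1/4 of the rows, split 1 gets 2/4, and 1/4 is discarded.
split_weights = [1, 2]
discard_weight = 1
total = sum(split_weights) + discard_weight
fractions = [w / total for w in split_weights]   # [0.25, 0.5]
discarded = discard_weight / total               # 0.25
```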
|
|
||||||
|
def split_sequential(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ratios: Optional[list[float]] = None,
|
||||||
|
counts: Optional[list[int]] = None,
|
||||||
|
fixed: Optional[int] = None,
|
||||||
|
split_names: Optional[list[str]] = None,
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Configure sequential splits for the permutation.
|
||||||
|
|
||||||
|
One of ratios, counts, or fixed must be provided.
|
||||||
|
|
||||||
|
If ratios are provided, they will be used to determine the relative size of each
|
||||||
|
split. For example, if ratios are [0.3, 0.7] then the first split will contain
|
||||||
|
30% of the rows and the second split will contain 70% of the rows.
|
||||||
|
|
||||||
|
If counts are provided, they will be used to determine the absolute number of
|
||||||
|
rows in each split. For example, if counts are [100, 200] then the first split
|
||||||
|
will contain 100 rows and the second split will contain 200 rows.
|
||||||
|
|
||||||
|
If fixed is provided, it will be used to determine the number of splits.
|
||||||
|
For example, if fixed is 3 then the permutation will be split evenly into 3
|
||||||
|
splits.
|
||||||
|
|
||||||
|
Rows will be assigned to splits sequentially. The first N1 rows are assigned to
|
||||||
|
split 1, the next N2 rows are assigned to split 2, etc.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_sequential(
|
||||||
|
ratios=ratios, counts=counts, fixed=fixed, split_names=split_names
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def split_calculated(
|
||||||
|
self, calculation: str, split_names: Optional[list[str]] = None
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Use pre-calculated splits for the permutation.
|
||||||
|
|
||||||
|
The calculation should be an SQL statement that returns an integer value between
|
||||||
|
0 and the number of splits - 1. For example, if you have 3 splits then the
|
||||||
|
calculation should return 0 for the first split, 1 for the second split, and 2
|
||||||
|
for the third split.
|
||||||
|
|
||||||
|
This can be used to implement any kind of user-defined split strategy.
|
||||||
|
|
||||||
|
The optional split_names can be provided to name the splits. If not provided,
|
||||||
|
the splits can only be referenced by their index.
|
||||||
|
"""
|
||||||
|
self._async.split_calculated(calculation, split_names=split_names)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def shuffle(
|
||||||
|
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
|
||||||
|
) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Randomly shuffle the rows in the permutation.
|
||||||
|
|
||||||
|
An optional seed can be provided to make the shuffle deterministic.
|
||||||
|
|
||||||
|
If a clump size is provided, then data will be shuffled as small "clumps"
|
||||||
|
of contiguous rows. This allows for a balance between randomization and
|
||||||
|
I/O performance. It can be useful when reading from cloud storage.
|
||||||
|
"""
|
||||||
|
self._async.shuffle(seed=seed, clump_size=clump_size)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def filter(self, filter: str) -> "PermutationBuilder":
|
||||||
|
"""
|
||||||
|
Configure a filter for the permutation.
|
||||||
|
|
||||||
|
The filter should be an SQL statement that returns a boolean value for each row.
|
||||||
|
Only rows where the filter is true will be included in the permutation.
|
||||||
|
"""
|
||||||
|
self._async.filter(filter)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def execute(self) -> LanceTable:
|
||||||
|
"""
|
||||||
|
Execute the configuration and create the permutation table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def do_execute():
|
||||||
|
inner_tbl = await self._async.execute()
|
||||||
|
return LanceTable.from_inner(inner_tbl)
|
||||||
|
|
||||||
|
return LOOP.run(do_execute())
|
||||||
|
|
||||||
|
|
||||||
|
def permutation_builder(table: LanceTable) -> PermutationBuilder:
|
||||||
|
return PermutationBuilder(table)
|
||||||
|
|
||||||
|
|
||||||
|
class Permutations:
|
||||||
|
"""
|
||||||
|
A collection of permutations indexed by name or ordinal index.
|
||||||
|
|
||||||
|
Splits are defined when the permutation is created. Splits can always be referenced
|
||||||
|
by their ordinal index. If names were provided when the permutation was created
|
||||||
|
then they can also be referenced by name.
|
||||||
|
|
||||||
|
Each permutation or "split" is a view of a portion of the base table. For more
|
||||||
|
details see [Permutation].
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
base_table: LanceTable
|
||||||
|
The base table that the permutations are based on.
|
||||||
|
permutation_table: LanceTable
|
||||||
|
The permutation table that defines the splits.
|
||||||
|
split_names: list[str]
|
||||||
|
The names of the splits.
|
||||||
|
split_dict: dict[str, int]
|
||||||
|
A dictionary mapping split names to their ordinal index.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> # Initial data
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("memory:///")
|
||||||
|
>>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(1000)])
|
||||||
|
>>> # Create a permutation
|
||||||
|
>>> perm_tbl = (
|
||||||
|
... permutation_builder(tbl)
|
||||||
|
... .split_random(ratios=[0.95, 0.05], split_names=["train", "test"])
|
||||||
|
... .shuffle()
|
||||||
|
... .execute()
|
||||||
|
... )
|
||||||
|
>>> # Read the permutations
|
||||||
|
>>> permutations = Permutations(tbl, perm_tbl)
|
||||||
|
>>> permutations["train"]
|
||||||
|
<lancedb.permutation.Permutation ...>
|
||||||
|
>>> permutations[0]
|
||||||
|
<lancedb.permutation.Permutation ...>
|
||||||
|
>>> permutations.split_names
|
||||||
|
['train', 'test']
|
||||||
|
>>> permutations.split_dict
|
||||||
|
{'train': 0, 'test': 1}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_table: LanceTable, permutation_table: LanceTable):
|
||||||
|
self.base_table = base_table
|
||||||
|
self.permutation_table = permutation_table
|
||||||
|
|
||||||
|
if permutation_table.schema.metadata is not None:
|
||||||
|
split_names = permutation_table.schema.metadata.get(b"split_names", None)
if split_names is not None:
    split_names = split_names.decode("utf-8")
|
||||||
|
if split_names is not None:
|
||||||
|
self.split_names = json.loads(split_names)
|
||||||
|
self.split_dict = {
|
||||||
|
name: idx for idx, name in enumerate(self.split_names)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# No split names are defined in the permutation table
|
||||||
|
self.split_names = []
|
||||||
|
self.split_dict = {}
|
||||||
|
else:
|
||||||
|
# No metadata is defined in the permutation table
|
||||||
|
self.split_names = []
|
||||||
|
self.split_dict = {}
|
||||||
|
|
||||||
|
def get_by_name(self, name: str) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Get a permutation by name.
|
||||||
|
|
||||||
|
If no split named `name` is found then an error will be raised.
|
||||||
|
"""
|
||||||
|
idx = self.split_dict.get(name, None)
|
||||||
|
if idx is None:
|
||||||
|
raise ValueError(f"No split named `{name}` found")
|
||||||
|
return self.get_by_index(idx)
|
||||||
|
|
||||||
|
def get_by_index(self, index: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Get a permutation by index.
|
||||||
|
"""
|
||||||
|
return Permutation.from_tables(self.base_table, self.permutation_table, index)
|
||||||
|
|
||||||
|
def __getitem__(self, name: Union[str, int]) -> "Permutation":
|
||||||
|
if isinstance(name, str):
|
||||||
|
return self.get_by_name(name)
|
||||||
|
elif isinstance(name, int):
|
||||||
|
return self.get_by_index(name)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid split name or index: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
class Transforms:
|
||||||
|
"""
|
||||||
|
Namespace for common transformation functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2python(batch: pa.RecordBatch) -> dict[str, list[Any]]:
|
||||||
|
return batch.to_pydict()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2arrow(batch: pa.RecordBatch) -> pa.RecordBatch:
|
||||||
|
return batch
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2numpy(batch: pa.RecordBatch) -> "np.ndarray":
|
||||||
|
return batch.to_pandas().to_numpy()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2pandas(batch: pa.RecordBatch) -> "pd.DataFrame":
|
||||||
|
return batch.to_pandas()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def arrow2polars() -> Callable[[pa.RecordBatch], "pl.DataFrame"]:
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
def impl(batch: pa.RecordBatch) -> pl.DataFrame:
|
||||||
|
return pl.from_arrow(batch)
|
||||||
|
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
# HuggingFace uses 10 which is pretty small
|
||||||
|
DEFAULT_BATCH_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
|
class Permutation:
|
||||||
|
"""
|
||||||
|
A Permutation is a view of a dataset that can be used as input to model training
|
||||||
|
and evaluation.
|
||||||
|
|
||||||
|
A Permutation fulfills the pytorch Dataset contract and is loosely modeled after the
|
||||||
|
huggingface Dataset so it should be easy to use with existing code.
|
||||||
|
|
||||||
|
A permutation is not a "materialized view" or copy of the underlying data. It is
|
||||||
|
calculated on the fly from the base table. As a result, it is truly "lazy" and does
|
||||||
|
not require materializing the entire dataset in memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reader: PermutationReader,
|
||||||
|
selection: dict[str, str],
|
||||||
|
batch_size: int,
|
||||||
|
transform_fn: Callable[[pa.RecordBatch], Any],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Internal constructor. Use [from_tables](#from_tables) instead.
|
||||||
|
"""
|
||||||
|
assert reader is not None, "reader is required"
|
||||||
|
assert selection is not None, "selection is required"
|
||||||
|
self.reader = reader
|
||||||
|
self.selection = selection
|
||||||
|
self.transform_fn = transform_fn
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a new permutation with the given selection
|
||||||
|
|
||||||
|
Does no validation of the selection and replaces it entirely. This is not
|
||||||
|
intended for public use.
|
||||||
|
"""
|
||||||
|
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
|
||||||
|
|
||||||
|
def _with_reader(self, reader: PermutationReader) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a new permutation with the given reader
|
||||||
|
|
||||||
|
This is an internal method and should not be used directly.
|
||||||
|
"""
|
||||||
|
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
|
||||||
|
|
||||||
|
def with_batch_size(self, batch_size: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a new permutation with the given batch size
|
||||||
|
"""
|
||||||
|
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def identity(cls, table: LanceTable) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates an identity permutation for the given table.
|
||||||
|
"""
|
||||||
|
return Permutation.from_tables(table, None, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tables(
|
||||||
|
cls,
|
||||||
|
base_table: LanceTable,
|
||||||
|
permutation_table: Optional[LanceTable] = None,
|
||||||
|
split: Optional[Union[str, int]] = None,
|
||||||
|
) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Creates a permutation from the given base table and permutation table.
|
||||||
|
|
||||||
|
A permutation table identifies which rows, and in what order, the data should
|
||||||
|
be read from the base table. For more details see the [PermutationBuilder]
|
||||||
|
class.
|
||||||
|
|
||||||
|
If no permutation table is provided, then the identity permutation will be
|
||||||
|
created. An identity permutation is a permutation that reads all rows in the
|
||||||
|
base table in the order they are stored.
|
||||||
|
|
||||||
|
The split parameter identifies which split to use. If no split is provided
|
||||||
|
then the first split will be used.
|
||||||
|
"""
|
||||||
|
assert base_table is not None, "base_table is required"
|
||||||
|
if split is not None:
|
||||||
|
if permutation_table is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot create a permutation on split `{split}`"
|
||||||
|
" because no permutation table is provided"
|
||||||
|
)
|
||||||
|
if isinstance(split, str):
|
||||||
|
if permutation_table.schema.metadata is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot create a permutation on split `{split}`"
|
||||||
|
" because no split names are defined in the permutation table"
|
||||||
|
)
|
||||||
|
split_names = permutation_table.schema.metadata.get(b"split_names", None)
if split_names is not None:
    split_names = split_names.decode("utf-8")
|
||||||
|
if split_names is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot create a permutation on split `{split}`"
|
||||||
|
" because no split names are defined in the permutation table"
|
||||||
|
)
|
||||||
|
split_names = json.loads(split_names)
|
||||||
|
try:
|
||||||
|
split = split_names.index(split)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot create a permutation on split `{split}`"
|
||||||
|
f" because split `{split}` is not defined in the "
|
||||||
|
"permutation table"
|
||||||
|
)
|
||||||
|
elif isinstance(split, int):
|
||||||
|
split = split
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid split: {split}")
|
||||||
|
else:
|
||||||
|
split = 0
|
||||||
|
|
||||||
|
async def do_from_tables():
|
||||||
|
reader = await PermutationReader.from_tables(
|
||||||
|
base_table, permutation_table, split
|
||||||
|
)
|
||||||
|
schema = await reader.output_schema(None)
|
||||||
|
initial_selection = {name: name for name in schema.names}
|
||||||
|
return cls(
|
||||||
|
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
|
||||||
|
)
|
||||||
|
|
||||||
|
return LOOP.run(do_from_tables())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def schema(self) -> pa.Schema:
|
||||||
|
async def do_output_schema():
|
||||||
|
return await self.reader.output_schema(self.selection)
|
||||||
|
|
||||||
|
return LOOP.run(do_output_schema())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_columns(self) -> int:
|
||||||
|
"""
|
||||||
|
The number of columns in the permutation
|
||||||
|
"""
|
||||||
|
return len(self.schema)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_rows(self) -> int:
|
||||||
|
"""
|
||||||
|
The number of rows in the permutation
|
||||||
|
"""
|
||||||
|
return self.reader.count_rows()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def column_names(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
The names of the columns in the permutation
|
||||||
|
"""
|
||||||
|
return self.schema.names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
The shape of the permutation
|
||||||
|
|
||||||
|
This will return self.num_rows, self.num_columns
|
||||||
|
"""
|
||||||
|
return self.num_rows, self.num_columns
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
The number of rows in the permutation
|
||||||
|
|
||||||
|
This is an alias for [num_rows][lancedb.permutation.Permutation.num_rows]
|
||||||
|
"""
|
||||||
|
return self.num_rows
|
||||||
|
|
||||||
|
def unique(self, _column: str) -> list[Any]:
|
||||||
|
"""
|
||||||
|
Get the unique values in the given column
|
||||||
|
"""
|
||||||
|
raise Exception("unique is not yet implemented")
|
||||||
|
|
||||||
|
def flatten(self) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Flatten the permutation
|
||||||
|
|
||||||
|
Each column with a struct type will be flattened into multiple columns.
|
||||||
|
|
||||||
|
This flattening operation happens at read time as a post-processing step
|
||||||
|
so this call is cheap and no data is copied or modified in the underlying
|
||||||
|
dataset.
|
||||||
|
"""
|
||||||
|
raise Exception("flatten is not yet implemented")
|
||||||
|
|
||||||
|
def remove_columns(self, columns: list[str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Remove the given columns from the permutation
|
||||||
|
|
||||||
|
Note: this does not actually modify the underlying dataset. It only changes
|
||||||
|
which columns are visible from this permutation. Also, this does not introduce
|
||||||
|
a post-processing step. Instead, we simply do not read those columns in the
|
||||||
|
first place.
|
||||||
|
|
||||||
|
If any of the provided columns does not exist in the current permutation then it
|
||||||
|
will be ignored (no error is raised for missing columns)
|
||||||
|
|
||||||
|
Returns a new permutation with the given columns removed. This does not modify
|
||||||
|
self.
|
||||||
|
"""
|
||||||
|
assert columns is not None, "columns is required"
|
||||||
|
|
||||||
|
new_selection = {
|
||||||
|
name: value for name, value in self.selection.items() if name not in columns
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(new_selection) == 0:
|
||||||
|
raise ValueError("Cannot remove all columns")
|
||||||
|
|
||||||
|
return self._with_selection(new_selection)
|
||||||
|
|
||||||
|
def rename_column(self, old_name: str, new_name: str) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Rename a column in the permutation
|
||||||
|
|
||||||
|
If there is no column named old_name then an error will be raised
|
||||||
|
If there is already a column named new_name then an error will be raised
|
||||||
|
|
||||||
|
Note: this does not actually modify the underlying dataset. It only changes
|
||||||
|
the name of the column that is visible from this permutation. This is a
|
||||||
|
post-processing step but done at the batch level and so it is very cheap.
|
||||||
|
No data will be copied.
|
||||||
|
"""
|
||||||
|
assert old_name is not None, "old_name is required"
|
||||||
|
assert new_name is not None, "new_name is required"
|
||||||
|
if old_name not in self.selection:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot rename column `{old_name}` because it does not exist"
|
||||||
|
)
|
||||||
|
if new_name in self.selection:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot rename column `{old_name}` to `{new_name}` because a column "
|
||||||
|
"with that name already exists"
|
||||||
|
)
|
||||||
|
new_selection = self.selection.copy()
|
||||||
|
new_selection[new_name] = new_selection[old_name]
|
||||||
|
del new_selection[old_name]
|
||||||
|
return self._with_selection(new_selection)
|
||||||
|
|
||||||
|
def rename_columns(self, column_map: dict[str, str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Rename the given columns in the permutation
|
||||||
|
|
||||||
|
If any of the columns do not exist then an error will be raised
|
||||||
|
If any of the new names already exist then an error will be raised
|
||||||
|
|
||||||
|
Note: this does not actually modify the underlying dataset. It only changes
|
||||||
|
the name of the column that is visible from this permutation. This is a
|
||||||
|
post-processing step but done at the batch level and so it is very cheap.
|
||||||
|
No data will be copied.
|
||||||
|
"""
|
||||||
|
assert column_map is not None, "column_map is required"
|
||||||
|
|
||||||
|
new_permutation = self
|
||||||
|
for old_name, new_name in column_map.items():
|
||||||
|
new_permutation = new_permutation.rename_column(old_name, new_name)
|
||||||
|
return new_permutation
|
||||||
|
|
||||||
|
def select_columns(self, columns: list[str]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Select the given columns from the permutation
|
||||||
|
|
||||||
|
This method refines the current selection, potentially removing columns. It
|
||||||
|
will not add back columns that were previously removed.
|
||||||
|
|
||||||
|
If any of the columns do not exist then an error will be raised
|
||||||
|
|
||||||
|
This does not introduce a post-processing step. It simply reduces the amount
|
||||||
|
of data we read.
|
||||||
|
"""
|
||||||
|
assert columns is not None, "columns is required"
|
||||||
|
if len(columns) == 0:
|
||||||
|
raise ValueError("Must select at least one column")
|
||||||
|
|
||||||
|
new_selection = {}
|
||||||
|
for name in columns:
|
||||||
|
value = self.selection.get(name, None)
|
||||||
|
if value is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot select column `{name}` because it does not exist"
|
||||||
|
)
|
||||||
|
new_selection[name] = value
|
||||||
|
return self._with_selection(new_selection)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Iterate over the permutation
|
||||||
|
"""
|
||||||
|
return self.iter(self.batch_size, skip_last_batch=True)
|
||||||
|
|
||||||
|
def iter(
|
||||||
|
self, batch_size: int, skip_last_batch: bool = False
|
||||||
|
) -> Iterator[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Iterate over the permutation in batches
|
||||||
|
|
||||||
|
If skip_last_batch is True, the last batch will be skipped if it is not a
|
||||||
|
multiple of batch_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_iter():
|
||||||
|
return await self.reader.read(self.selection, batch_size=batch_size)
|
||||||
|
|
||||||
|
async_iter = LOOP.run(get_iter())
|
||||||
|
|
||||||
|
async def get_next():
|
||||||
|
return await async_iter.__anext__()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
batch = LOOP.run(get_next())
|
||||||
|
if batch.num_rows == batch_size or not skip_last_batch:
|
||||||
|
yield self.transform_fn(batch)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
return
|
||||||
|
|
||||||
|
def with_format(
|
||||||
|
self, format: Literal["numpy", "python", "pandas", "arrow", "torch", "polars"]
|
||||||
|
) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Set the format for batches
|
||||||
|
|
||||||
|
If this method is not called, the "python" format will be used.
|
||||||
|
|
||||||
|
The format can be one of:
|
||||||
|
- "numpy" - the batch will be a dict of numpy arrays (one per column)
|
||||||
|
- "python" - the batch will be a dict of lists (one per column)
|
||||||
|
- "pandas" - the batch will be a pandas DataFrame
|
||||||
|
- "arrow" - the batch will be a pyarrow RecordBatch
|
||||||
|
- "torch" - the batch will be a two dimensional torch tensor
|
||||||
|
- "polars" - the batch will be a polars DataFrame
|
||||||
|
|
||||||
|
Conversion may or may not involve a data copy. Lance uses Arrow internally
|
||||||
|
and so it is able to convert to the arrow and polars formats with zero copies.
|
||||||
|
|
||||||
|
Conversion to torch will be zero-copy but will only support a subset of data
|
||||||
|
types (numeric types).
|
||||||
|
|
||||||
|
Conversion to numpy and/or pandas will typically be zero-copy for numeric
|
||||||
|
types. Conversion of strings, lists, and structs will require creating python
|
||||||
|
objects and this is not zero-copy.
|
||||||
|
|
||||||
|
For custom formatting, use [with_transform](#with_transform) which overrides
|
||||||
|
this method.
|
||||||
|
"""
|
||||||
|
assert format is not None, "format is required"
|
||||||
|
if format == "python":
|
||||||
|
return self.with_transform(Transforms.arrow2python)
|
||||||
|
elif format == "numpy":
|
||||||
|
return self.with_transform(Transforms.arrow2numpy)
|
||||||
|
elif format == "pandas":
|
||||||
|
return self.with_transform(Transforms.arrow2pandas)
|
||||||
|
elif format == "arrow":
|
||||||
|
return self.with_transform(Transforms.arrow2arrow)
|
||||||
|
elif format == "torch":
|
||||||
|
return self.with_transform(batch_to_tensor)
|
||||||
|
elif format == "polars":
|
||||||
|
return self.with_transform(Transforms.arrow2polars())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid format: {format}")
|
||||||
|
|
||||||
|
def with_transform(self, transform: Callable[[pa.RecordBatch], Any]) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Set a custom transform for the permutation
|
||||||
|
|
||||||
|
The transform is a callable that will be invoked with each record batch. The
|
||||||
|
return value will be used as the batch for iteration.
|
||||||
|
|
||||||
|
Note: transforms are not invoked in parallel. This method is not a good place
|
||||||
|
for expensive operations such as image decoding.
|
||||||
|
"""
|
||||||
|
assert transform is not None, "transform is required"
|
||||||
|
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||||
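To make the with_transform contract concrete, here is a hedged sketch of a custom transform; it relies only on the batch being a pyarrow RecordBatch, as documented above. The column name "text" is hypothetical.

```python
import pyarrow as pa

def add_text_length(batch: pa.RecordBatch) -> dict:
    # Convert to python lists and derive an extra feature column.
    data = batch.to_pydict()
    data["text_len"] = [len(t) for t in data.get("text", [])]
    return data

# perm = perm.with_transform(add_text_length)
```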
|
|
||||||
|
def __getitem__(self, index: int) -> Any:
|
||||||
|
"""
|
||||||
|
Return a single row from the permutation
|
||||||
|
|
||||||
|
The output will always be a python dictionary regardless of the format.
|
||||||
|
|
||||||
|
This method is mostly useful for debugging and exploration. For actual
|
||||||
|
processing use [iter](#iter) or a torch data loader to perform batched
|
||||||
|
processing.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@deprecated(details="Use with_skip instead")
|
||||||
|
def skip(self, skip: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Skip the first `skip` rows of the permutation
|
||||||
|
|
||||||
|
Note: this method returns a new permutation and does not modify `self`
|
||||||
|
It is provided for compatibility with the huggingface Dataset API.
|
||||||
|
|
||||||
|
Use [with_skip](#with_skip) instead to avoid confusion.
|
||||||
|
"""
|
||||||
|
return self.with_skip(skip)
|
||||||
|
|
||||||
|
def with_skip(self, skip: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Skip the first `skip` rows of the permutation
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def do_with_skip():
|
||||||
|
reader = await self.reader.with_offset(skip)
|
||||||
|
return self._with_reader(reader)
|
||||||
|
|
||||||
|
return LOOP.run(do_with_skip())
|
||||||
|
|
||||||
|
@deprecated(details="Use with_take instead")
|
||||||
|
def take(self, limit: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Limit the permutation to `limit` rows (following any `skip`)
|
||||||
|
|
||||||
|
Note: this method returns a new permutation and does not modify `self`
|
||||||
|
It is provided for compatibility with the huggingface Dataset API.
|
||||||
|
|
||||||
|
Use [with_take](#with_take) instead to avoid confusion.
|
||||||
|
"""
|
||||||
|
return self.with_take(limit)
|
||||||
|
|
||||||
|
def with_take(self, limit: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Limit the permutation to `limit` rows (following any `skip`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def do_with_take():
|
||||||
|
reader = await self.reader.with_limit(limit)
|
||||||
|
return self._with_reader(reader)
|
||||||
|
|
||||||
|
return LOOP.run(do_with_take())
|
||||||
|
|
||||||
|
@deprecated(details="Use with_repeat instead")
|
||||||
|
def repeat(self, times: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Repeat the permutation `times` times
|
||||||
|
|
||||||
|
Note: this method returns a new permutation and does not modify `self`
|
||||||
|
It is provided for compatibility with the huggingface Dataset API.
|
||||||
|
|
||||||
|
Use [with_repeat](#with_repeat) instead to avoid confusion.
|
||||||
|
"""
|
||||||
|
return self.with_repeat(times)
|
||||||
|
|
||||||
|
def with_repeat(self, times: int) -> "Permutation":
|
||||||
|
"""
|
||||||
|
Repeat the permutation `times` times
|
||||||
|
"""
|
||||||
|
raise Exception("with_repeat is not yet implemented")
|
||||||
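Putting the pieces of this new module together, a minimal end-to-end sketch might look like the following. The PyTorch DataLoader wiring is only implied by the "fulfills the pytorch Dataset contract" note above, so it is left as a comment.

```python
import lancedb
from lancedb.permutation import Permutations, permutation_builder

db = lancedb.connect("/tmp/perm-demo")
tbl = db.create_table(
    "train_data", data=[{"x": float(i), "y": i % 2} for i in range(1_000)]
)

perm_tbl = (
    permutation_builder(tbl)
    .split_random(ratios=[0.8, 0.2], split_names=["train", "test"], seed=42)
    .shuffle(seed=42, clump_size=64)
    .execute()
)

train = Permutations(tbl, perm_tbl)["train"].with_format("python").with_batch_size(128)
for batch in train:   # each batch is a dict of python lists
    pass              # feed into a training loop / torch DataLoader here
```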
@@ -37,7 +37,7 @@ from .rerankers.base import Reranker
 from .rerankers.rrf import RRFReranker
 from .rerankers.util import check_reranker_result
 from .util import flatten_columns
+from lancedb._lancedb import fts_query_to_json
 from typing_extensions import Annotated

 if TYPE_CHECKING:
@@ -124,6 +124,24 @@ class FullTextQuery(ABC):
         """
         pass

+    def to_json(self) -> str:
+        """
+        Convert the query to a JSON string.
+
+        Returns
+        -------
+        str
+            A JSON string representation of the query.
+
+        Examples
+        --------
+        >>> from lancedb.query import MatchQuery
+        >>> query = MatchQuery("puppy", "text", fuzziness=2)
+        >>> query.to_json()
+        '{"match":{"column":"text","terms":"puppy","boost":1.0,"fuzziness":2,"max_expansions":50,"operator":"Or","prefix_length":0}}'
+        """
+        return fts_query_to_json(self)
+
     def __and__(self, other: "FullTextQuery") -> "FullTextQuery":
         """
         Combine two queries with a logical AND operation.
@@ -288,6 +306,8 @@ class BooleanQuery(FullTextQuery):
    ----------
    queries : list[tuple(Occur, FullTextQuery)]
        The list of queries with their occurrence requirements.
+        Each tuple contains an Occur value (MUST, SHOULD, or MUST_NOT)
+        and a FullTextQuery to apply.
    """

    queries: list[tuple[Occur, FullTextQuery]]
@@ -1237,6 +1257,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._refine_factor = refine_factor
         return self

+    def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query
+
+        This does not execute the query.
+        """
+        return self._table._output_schema(self.to_query_object())
+
     def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
         """
         Execute the query and return the results as an
@@ -1452,6 +1480,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
             offset=self._offset,
         )

+    def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query
+
+        This does not execute the query.
+        """
+        return self._table._output_schema(self.to_query_object())
+
     def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
         path, fs, exist = self._table._get_fts_index_path()
         if exist:
@@ -1595,6 +1631,10 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
             offset=self._offset,
         )

+    def output_schema(self) -> pa.Schema:
+        query = self.to_query_object()
+        return self._table._output_schema(query)
+
     def to_batches(
         self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
     ) -> pa.RecordBatchReader:
@@ -2238,6 +2278,14 @@ class AsyncQueryBase(object):
             )
         )

+    async def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query
+
+        This does not execute the query.
+        """
+        return await self._inner.output_schema()
+
     async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
         """
         Execute the query and collect the results into an Apache Arrow Table.
@@ -3193,6 +3241,14 @@ class BaseQueryBuilder(object):
         self._inner.with_row_id()
         return self

+    def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query
+
+        This does not execute the query.
+        """
+        return LOOP.run(self._inner.output_schema())
+
     def to_batches(
         self,
         *,
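Taken together, the query.py changes above add JSON serialization for full-text queries and a way to inspect result schemas without running a query. A hedged sketch (the table handle and vector values are made up for illustration):

# Hedged sketch, not part of the diff.
from lancedb.query import MatchQuery, PhraseQuery

# __and__ combines two full-text queries into a boolean query,
# and to_json() serializes it for inspection or logging.
combined = MatchQuery("puppy", "text") & PhraseQuery("good dog", "text")
print(combined.to_json())

# output_schema() reports the result columns without executing the query;
# `tbl` is assumed to be an open table with a default vector column.
print(tbl.search([0.1, 0.2]).output_schema())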
@@ -436,6 +436,9 @@ class RemoteTable(Table):
     def _analyze_plan(self, query: Query) -> str:
         return LOOP.run(self._table._analyze_plan(query))

+    def _output_schema(self, query: Query) -> pa.Schema:
+        return LOOP.run(self._table._output_schema(query))
+
     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
         that can be used to create a "merge insert" operation.
@@ -21,6 +21,8 @@ class VoyageAIReranker(Reranker):
    ----------
    model_name : str, default "rerank-english-v2.0"
        The name of the cross encoder model to use. Available voyageai models are:
+        - rerank-2.5
+        - rerank-2.5-lite
        - rerank-2
        - rerank-2-lite
    column : str, default "text"
@@ -44,7 +44,7 @@ import numpy as np

 from .common import DATA, VEC, VECTOR_COLUMN_NAME
 from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
-from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+from .index import BTree, IvfFlat, IvfPq, Bitmap, IvfRq, LabelList, HnswPq, HnswSq, FTS
 from .merge import LanceMergeInsertBuilder
 from .pydantic import LanceModel, model_to_dict
 from .query import (
@@ -74,6 +74,7 @@ from .index import lang_mapping


 if TYPE_CHECKING:
+    from .db import LanceDBConnection
     from ._lancedb import (
         Table as LanceDBTable,
         OptimizeStats,
@@ -88,7 +89,6 @@ if TYPE_CHECKING:
         MergeResult,
         UpdateResult,
     )
-    from .db import LanceDBConnection
     from .index import IndexConfig
     import pandas
     import PIL
@@ -1248,6 +1248,9 @@ class Table(ABC):
     @abstractmethod
     def _analyze_plan(self, query: Query) -> str: ...

+    @abstractmethod
+    def _output_schema(self, query: Query) -> pa.Schema: ...
+
     @abstractmethod
     def _do_merge(
         self,
@@ -1707,22 +1710,38 @@ class LanceTable(Table):
         namespace: List[str] = [],
         storage_options: Optional[Dict[str, str]] = None,
         index_cache_size: Optional[int] = None,
+        _async: AsyncTable = None,
     ):
         self._conn = connection
         self._namespace = namespace
-        self._table = LOOP.run(
-            connection._conn.open_table(
-                name,
-                namespace=namespace,
-                storage_options=storage_options,
-                index_cache_size=index_cache_size,
-            )
-        )
+        if _async is not None:
+            self._table = _async
+        else:
+            self._table = LOOP.run(
+                connection._conn.open_table(
+                    name,
+                    namespace=namespace,
+                    storage_options=storage_options,
+                    index_cache_size=index_cache_size,
+                )
+            )

     @property
     def name(self) -> str:
         return self._table.name

+    @classmethod
+    def from_inner(cls, tbl: LanceDBTable):
+        from .db import LanceDBConnection
+
+        async_tbl = AsyncTable(tbl)
+        conn = LanceDBConnection.from_inner(tbl.database())
+        return cls(
+            conn,
+            async_tbl.name,
+            _async=async_tbl,
+        )
+
     @classmethod
     def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
         tbl = cls(db, name, namespace=namespace, **kwargs)
@@ -1991,7 +2010,7 @@ class LanceTable(Table):
         index_cache_size: Optional[int] = None,
         num_bits: int = 8,
         index_type: Literal[
-            "IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
+            "IVF_FLAT", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
         ] = "IVF_PQ",
         max_iterations: int = 50,
         sample_rate: int = 256,
@@ -2039,6 +2058,15 @@ class LanceTable(Table):
                 sample_rate=sample_rate,
                 target_partition_size=target_partition_size,
             )
+        elif index_type == "IVF_RQ":
+            config = IvfRq(
+                distance_type=metric,
+                num_partitions=num_partitions,
+                num_bits=num_bits,
+                max_iterations=max_iterations,
+                sample_rate=sample_rate,
+                target_partition_size=target_partition_size,
+            )
         elif index_type == "IVF_HNSW_PQ":
             config = HnswPq(
                 distance_type=metric,
@@ -2736,6 +2764,9 @@ class LanceTable(Table):
     def _analyze_plan(self, query: Query) -> str:
         return LOOP.run(self._table._analyze_plan(query))

+    def _output_schema(self, query: Query) -> pa.Schema:
+        return LOOP.run(self._table._output_schema(query))
+
     def _do_merge(
         self,
         merge: LanceMergeInsertBuilder,
@@ -2747,6 +2778,10 @@ class LanceTable(Table):
             self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
         )

+    @property
+    def _inner(self) -> LanceDBTable:
+        return self._table._inner
+
     @deprecation.deprecated(
         deprecated_in="0.21.0",
         current_version=__version__,
@@ -3330,7 +3365,7 @@ class AsyncTable:
         *,
         replace: Optional[bool] = None,
         config: Optional[
-            Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
+            Union[IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
         ] = None,
         wait_timeout: Optional[timedelta] = None,
         name: Optional[str] = None,
@@ -3369,11 +3404,12 @@ class AsyncTable:
         """
         if config is not None:
             if not isinstance(
-                config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
+                config,
+                (IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS),
             ):
                 raise TypeError(
-                    "config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"
-                    " Bitmap, LabelList, or FTS"
+                    "config must be an instance of IvfPq, IvfRq, HnswPq, HnswSq, BTree,"
+                    " Bitmap, LabelList, or FTS, but got " + str(type(config))
                 )
         try:
             await self._inner.create_index(
@@ -3888,6 +3924,10 @@ class AsyncTable:
         async_query = self._sync_query_to_async(query)
         return await async_query.analyze_plan()

+    async def _output_schema(self, query: Query) -> pa.Schema:
+        async_query = self._sync_query_to_async(query)
+        return await async_query.output_schema()
+
     async def _do_merge(
         self,
         merge: LanceMergeInsertBuilder,
@@ -18,10 +18,17 @@ AddMode = Literal["append", "overwrite"]
 CreateMode = Literal["create", "overwrite"]

 # Index type literals
-VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
+VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", "IVF_RQ"]
 ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
 IndexType = Literal[
-    "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
+    "IVF_PQ",
+    "IVF_HNSW_PQ",
+    "IVF_HNSW_SQ",
+    "FTS",
+    "BTREE",
+    "BITMAP",
+    "LABEL_LIST",
+    "IVF_RQ",
 ]

 # Tokenizer literals
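With `IVF_RQ` now accepted both as an `index_type` literal and as an `IvfRq` config object, index creation looks roughly like this (a sketch under assumptions, not part of the diff; the database path, table name, and chosen parameters are illustrative):

# Hedged sketch, not part of the diff.
import lancedb
from lancedb.index import IvfRq

db = lancedb.connect("/tmp/example-db")
tbl = db.open_table("my_vectors")  # assumed existing table with a vector column

# Sync API: request the new index type by its string literal.
tbl.create_index(metric="cosine", index_type="IVF_RQ", num_bits=1)

# Async API: pass an explicit IvfRq config instead.
# await async_tbl.create_index("vector", config=IvfRq(num_bits=1))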
@@ -366,3 +366,56 @@ def add_note(base_exception: BaseException, note: str):
         )
     else:
         raise ValueError("Cannot add note to exception")
+
+
+def tbl_to_tensor(tbl: pa.Table):
+    """
+    Convert a PyArrow Table to a PyTorch Tensor.
+
+    Each column is converted to a tensor (using zero-copy via DLPack)
+    and the columns are then stacked into a single tensor.
+
+    Fails if torch is not installed.
+    Fails if any column is more than one chunk.
+    Fails if a column's data type is not supported by PyTorch.
+
+    Parameters
+    ----------
+    tbl : pa.Table or pa.RecordBatch
+        The table or record batch to convert to a tensor.
+
+    Returns
+    -------
+    torch.Tensor: The tensor containing the columns of the table.
+    """
+    torch = attempt_import_or_raise("torch", "torch")
+
+    def to_tensor(col: pa.ChunkedArray):
+        if col.num_chunks > 1:
+            raise Exception("Single batch was too large to fit into a one-chunk table")
+        return torch.from_dlpack(col.chunk(0))
+
+    return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
+
+
+def batch_to_tensor(batch: pa.RecordBatch):
+    """
+    Convert a PyArrow RecordBatch to a PyTorch Tensor.
+
+    Each column is converted to a tensor (using zero-copy via DLPack)
+    and the columns are then stacked into a single tensor.
+
+    Fails if torch is not installed.
+    Fails if a column's data type is not supported by PyTorch.
+
+    Parameters
+    ----------
+    batch : pa.RecordBatch
+        The record batch to convert to a tensor.
+
+    Returns
+    -------
+    torch.Tensor: The tensor containing the columns of the record batch.
+    """
+    torch = attempt_import_or_raise("torch", "torch")
+    return torch.stack([torch.from_dlpack(col) for col in batch.columns])
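A quick sanity check of the tensor helpers added above (requires torch and pyarrow; the module path for the import is assumed from this diff):

# Hedged sketch, not part of the diff.
import pyarrow as pa
from lancedb.util import tbl_to_tensor  # module path assumed

tbl = pa.table({"id": [1, 2, 3], "value": [10, 20, 30]})
tensor = tbl_to_tensor(tbl)
print(tensor.shape)  # expected torch.Size([2, 3]): one stacked row per column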
@@ -532,6 +532,27 @@ def test_voyageai_embedding_function():
     assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
+)
+def test_voyageai_embedding_function_contextual_model():
+    voyageai = (
+        get_registry().get("voyageai").create(name="voyage-context-3", max_retries=0)
+    )
+
+    class TextModel(LanceModel):
+        text: str = voyageai.SourceField()
+        vector: Vector(voyageai.ndims()) = voyageai.VectorField()
+
+    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
+    db = lancedb.connect("~/lancedb")
+    tbl = db.create_table("test", schema=TextModel, mode="overwrite")
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(
     os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
@@ -656,6 +677,106 @@ def test_colpali(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
importlib.util.find_spec("colpali_engine") is None,
|
||||||
|
reason="colpali_engine not installed",
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name",
|
||||||
|
[
|
||||||
|
"vidore/colSmol-256M",
|
||||||
|
"vidore/colqwen2.5-v0.2",
|
||||||
|
"vidore/colpali-v1.3",
|
||||||
|
"vidore/colqwen2-v1.0",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_colpali_models(tmp_path, model_name):
|
||||||
|
import requests
|
||||||
|
from lancedb.pydantic import LanceModel
|
||||||
|
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
registry = get_registry()
|
||||||
|
func = registry.get("colpali").create(model_name=model_name)
|
||||||
|
|
||||||
|
class MediaItems(LanceModel):
|
||||||
|
text: str
|
||||||
|
image_uri: str = func.SourceField()
|
||||||
|
image_bytes: bytes = func.SourceField()
|
||||||
|
image_vectors: MultiVector(func.ndims()) = func.VectorField()
|
||||||
|
|
||||||
|
table = db.create_table(f"media_{model_name.replace('/', '_')}", schema=MediaItems)
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
"a cute cat playing with yarn",
|
||||||
|
]
|
||||||
|
|
||||||
|
uris = [
|
||||||
|
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||||
|
]
|
||||||
|
|
||||||
|
image_bytes = [requests.get(uri).content for uri in uris]
|
||||||
|
|
||||||
|
table.add(
|
||||||
|
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
|
||||||
|
)
|
||||||
|
|
||||||
|
image_results = (
|
||||||
|
table.search("fluffy companion", vector_column_name="image_vectors")
|
||||||
|
.limit(1)
|
||||||
|
.to_pydantic(MediaItems)[0]
|
||||||
|
)
|
||||||
|
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
|
||||||
|
|
||||||
|
first_row = table.to_arrow().to_pylist()[0]
|
||||||
|
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
|
||||||
|
assert len(first_row["image_vectors"][0]) == func.ndims(), (
|
||||||
|
"Vector dimension mismatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
importlib.util.find_spec("colpali_engine") is None,
|
||||||
|
reason="colpali_engine not installed",
|
||||||
|
)
|
||||||
|
def test_colpali_pooling(tmp_path):
|
||||||
|
registry = get_registry()
|
||||||
|
model_name = "vidore/colSmol-256M"
|
||||||
|
test_sentence = "a test sentence for pooling"
|
||||||
|
|
||||||
|
# 1. Get embeddings with no pooling
|
||||||
|
func_no_pool = registry.get("colpali").create(
|
||||||
|
model_name=model_name, pooling_strategy=None
|
||||||
|
)
|
||||||
|
unpooled_embeddings = func_no_pool.generate_text_embeddings([test_sentence])[0]
|
||||||
|
original_length = len(unpooled_embeddings)
|
||||||
|
assert original_length > 1
|
||||||
|
|
||||||
|
# 2. Test hierarchical pooling
|
||||||
|
func_hierarchical = registry.get("colpali").create(
|
||||||
|
model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
|
||||||
|
)
|
||||||
|
hierarchical_embeddings = func_hierarchical.generate_text_embeddings(
|
||||||
|
[test_sentence]
|
||||||
|
)[0]
|
||||||
|
expected_hierarchical_length = (original_length + 1) // 2
|
||||||
|
assert len(hierarchical_embeddings) == expected_hierarchical_length
|
||||||
|
|
||||||
|
# 3. Test lambda pooling
|
||||||
|
def simple_pool_func(tensor):
|
||||||
|
return tensor[::2]
|
||||||
|
|
||||||
|
func_lambda = registry.get("colpali").create(
|
||||||
|
model_name=model_name,
|
||||||
|
pooling_strategy="lambda",
|
||||||
|
pooling_func=simple_pool_func,
|
||||||
|
)
|
||||||
|
lambda_embeddings = func_lambda.generate_text_embeddings([test_sentence])[0]
|
||||||
|
expected_lambda_length = (original_length + 1) // 2
|
||||||
|
assert len(lambda_embeddings) == expected_lambda_length
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_siglip(tmp_path, test_images, query_image_bytes):
|
def test_siglip(tmp_path, test_images, query_image_bytes):
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|||||||
@@ -20,7 +20,14 @@ from unittest import mock
 import lancedb as ldb
 from lancedb.db import DBConnection
 from lancedb.index import FTS
-from lancedb.query import BoostQuery, MatchQuery, MultiMatchQuery, PhraseQuery
+from lancedb.query import (
+    BoostQuery,
+    MatchQuery,
+    MultiMatchQuery,
+    PhraseQuery,
+    BooleanQuery,
+    Occur,
+)
 import numpy as np
 import pyarrow as pa
 import pandas as pd
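The widened import above pulls in `BooleanQuery` and `Occur`; a hedged sketch of how they compose (it mirrors the `to_json` test added later in this diff):

# Hedged sketch, not part of the diff.
from lancedb.query import BooleanQuery, MatchQuery, Occur

# Require "puppy" and exclude "training" in a full-text search.
query = BooleanQuery(
    [
        (Occur.MUST, MatchQuery("puppy", "text")),
        (Occur.MUST_NOT, MatchQuery("training", "text")),
    ]
)
print(query.to_json())  # serializes the boolean query tree to JSON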
@@ -727,3 +734,146 @@ def test_fts_ngram(mem_db: DBConnection):
|
|||||||
results = table.search("la", query_type="fts").limit(10).to_list()
|
results = table.search("la", query_type="fts").limit(10).to_list()
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert set(r["text"] for r in results) == {"lance database", "lance is cool"}
|
assert set(r["text"] for r in results) == {"lance database", "lance is cool"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_fts_query_to_json():
|
||||||
|
"""Test that FTS query to_json() produces valid JSON strings with exact format."""
|
||||||
|
|
||||||
|
# Test MatchQuery - basic
|
||||||
|
match_query = MatchQuery("hello world", "text")
|
||||||
|
json_str = match_query.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"match":{"column":"text","terms":"hello world","boost":1.0,'
|
||||||
|
'"fuzziness":0,"max_expansions":50,"operator":"Or","prefix_length":0}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test MatchQuery with options
|
||||||
|
match_query = MatchQuery("puppy", "text", fuzziness=2, boost=1.5, prefix_length=3)
|
||||||
|
json_str = match_query.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"match":{"column":"text","terms":"puppy","boost":1.5,"fuzziness":2,'
|
||||||
|
'"max_expansions":50,"operator":"Or","prefix_length":3}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test PhraseQuery
|
||||||
|
phrase_query = PhraseQuery("quick brown fox", "title")
|
||||||
|
json_str = phrase_query.to_json()
|
||||||
|
expected = '{"phrase":{"column":"title","terms":"quick brown fox","slop":0}}'
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test PhraseQuery with slop
|
||||||
|
phrase_query = PhraseQuery("quick brown", "title", slop=2)
|
||||||
|
json_str = phrase_query.to_json()
|
||||||
|
expected = '{"phrase":{"column":"title","terms":"quick brown","slop":2}}'
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test BooleanQuery with MUST
|
||||||
|
must_query = BooleanQuery(
|
||||||
|
[
|
||||||
|
(Occur.MUST, MatchQuery("puppy", "text")),
|
||||||
|
(Occur.MUST, MatchQuery("runs", "text")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
json_str = must_query.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"boolean":{"should":[],"must":[{"match":{"column":"text","terms":"puppy",'
|
||||||
|
'"boost":1.0,"fuzziness":0,"max_expansions":50,"operator":"Or",'
|
||||||
|
'"prefix_length":0}},{"match":{"column":"text","terms":"runs","boost":1.0,'
|
||||||
|
'"fuzziness":0,"max_expansions":50,"operator":"Or","prefix_length":0}}],'
|
||||||
|
'"must_not":[]}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test BooleanQuery with SHOULD
|
||||||
|
should_query = BooleanQuery(
|
||||||
|
[
|
||||||
|
(Occur.SHOULD, MatchQuery("cat", "text")),
|
||||||
|
(Occur.SHOULD, MatchQuery("dog", "text")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
json_str = should_query.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"boolean":{"should":[{"match":{"column":"text","terms":"cat","boost":1.0,'
|
||||||
|
'"fuzziness":0,"max_expansions":50,"operator":"Or","prefix_length":0}},'
|
||||||
|
'{"match":{"column":"text","terms":"dog","boost":1.0,"fuzziness":0,'
|
||||||
|
'"max_expansions":50,"operator":"Or","prefix_length":0}}],"must":[],'
|
||||||
|
'"must_not":[]}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test BooleanQuery with MUST_NOT
|
||||||
|
must_not_query = BooleanQuery(
|
||||||
|
[
|
||||||
|
(Occur.MUST, MatchQuery("puppy", "text")),
|
||||||
|
(Occur.MUST_NOT, MatchQuery("training", "text")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
json_str = must_not_query.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"boolean":{"should":[],"must":[{"match":{"column":"text","terms":"puppy",'
|
||||||
|
'"boost":1.0,"fuzziness":0,"max_expansions":50,"operator":"Or",'
|
||||||
|
'"prefix_length":0}}],"must_not":[{"match":{"column":"text",'
|
||||||
|
'"terms":"training","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||||
|
'"operator":"Or","prefix_length":0}}]}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test BoostQuery
|
||||||
|
positive = MatchQuery("puppy", "text")
|
||||||
|
negative = MatchQuery("training", "text")
|
||||||
|
boost_query = BoostQuery(positive, negative, negative_boost=0.3)
|
||||||
|
json_str = boost_query.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"boost":{"positive":{"match":{"column":"text","terms":"puppy",'
|
||||||
|
'"boost":1.0,"fuzziness":0,"max_expansions":50,"operator":"Or",'
|
||||||
|
'"prefix_length":0}},"negative":{"match":{"column":"text",'
|
||||||
|
'"terms":"training","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||||
|
'"operator":"Or","prefix_length":0}},"negative_boost":0.3}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test MultiMatchQuery
|
||||||
|
multi_match = MultiMatchQuery("python", ["tags", "title"])
|
||||||
|
json_str = multi_match.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"multi_match":{"query":"python","columns":["tags","title"],'
|
||||||
|
'"boost":[1.0,1.0]}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|
||||||
|
# Test complex nested BooleanQuery
|
||||||
|
inner1 = BooleanQuery(
|
||||||
|
[
|
||||||
|
(Occur.MUST, MatchQuery("python", "tags")),
|
||||||
|
(Occur.MUST, MatchQuery("tutorial", "title")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
inner2 = BooleanQuery(
|
||||||
|
[
|
||||||
|
(Occur.MUST, MatchQuery("rust", "tags")),
|
||||||
|
(Occur.MUST, MatchQuery("guide", "title")),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
complex_query = BooleanQuery(
|
||||||
|
[
|
||||||
|
(Occur.SHOULD, inner1),
|
||||||
|
(Occur.SHOULD, inner2),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
json_str = complex_query.to_json()
|
||||||
|
expected = (
|
||||||
|
'{"boolean":{"should":[{"boolean":{"should":[],"must":[{"match":'
|
||||||
|
'{"column":"tags","terms":"python","boost":1.0,"fuzziness":0,'
|
||||||
|
'"max_expansions":50,"operator":"Or","prefix_length":0}},{"match":'
|
||||||
|
'{"column":"title","terms":"tutorial","boost":1.0,"fuzziness":0,'
|
||||||
|
'"max_expansions":50,"operator":"Or","prefix_length":0}}],"must_not":[]}}'
|
||||||
|
',{"boolean":{"should":[],"must":[{"match":{"column":"tags",'
|
||||||
|
'"terms":"rust","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||||
|
'"operator":"Or","prefix_length":0}},{"match":{"column":"title",'
|
||||||
|
'"terms":"guide","boost":1.0,"fuzziness":0,"max_expansions":50,'
|
||||||
|
'"operator":"Or","prefix_length":0}}],"must_not":[]}}],"must":[],'
|
||||||
|
'"must_not":[]}}'
|
||||||
|
)
|
||||||
|
assert json_str == expected
|
||||||
|
|||||||
@@ -8,7 +8,17 @@ import pyarrow as pa
 import pytest
 import pytest_asyncio
 from lancedb import AsyncConnection, AsyncTable, connect_async
-from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+from lancedb.index import (
+    BTree,
+    IvfFlat,
+    IvfPq,
+    IvfRq,
+    Bitmap,
+    LabelList,
+    HnswPq,
+    HnswSq,
+    FTS,
+)


 @pytest_asyncio.fixture
@@ -195,6 +205,16 @@ async def test_create_4bit_ivfpq_index(some_table: AsyncTable):
     assert stats.loss >= 0.0


+@pytest.mark.asyncio
+async def test_create_ivfrq_index(some_table: AsyncTable):
+    await some_table.create_index("vector", config=IvfRq(num_bits=1))
+    indices = await some_table.list_indices()
+    assert len(indices) == 1
+    assert indices[0].index_type == "IvfRq"
+    assert indices[0].columns == ["vector"]
+    assert indices[0].name == "vector_idx"
+
+
 @pytest.mark.asyncio
 async def test_create_hnswpq_index(some_table: AsyncTable):
     await some_table.create_index("vector", config=HnswPq(num_partitions=10))
@@ -59,6 +59,14 @@ class TempNamespace(LanceNamespace):
             root
         ]  # Reference to shared namespaces

+    def namespace_id(self) -> str:
+        """Return a human-readable unique identifier for this namespace instance.
+
+        Returns:
+            A unique identifier string based on the root directory
+        """
+        return f"TempNamespace {{ root: '{self.config.root}' }}"
+
     def list_tables(self, request: ListTablesRequest) -> ListTablesResponse:
         """List all tables in the namespace."""
         if not request.id:
python/python/tests/test_permutation.py (new file, 943 lines)
@@ -0,0 +1,943 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import math
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lancedb import DBConnection, Table, connect
|
||||||
|
from lancedb.permutation import Permutation, Permutations, permutation_builder
|
||||||
|
|
||||||
|
|
||||||
|
def test_permutation_persistence(tmp_path):
|
||||||
|
db = connect(tmp_path)
|
||||||
|
tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)}))
|
||||||
|
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute()
|
||||||
|
)
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
re_open = db.open_table("test_permutation")
|
||||||
|
assert re_open.count_rows() == 100
|
||||||
|
|
||||||
|
assert permutation_tbl.to_arrow() == re_open.to_arrow()
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_random_ratios(mem_db):
|
||||||
|
"""Test random splitting with ratios."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||||
|
)
|
||||||
|
permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()
|
||||||
|
|
||||||
|
# Check that the table was created and has data
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
# Check that split_id column exists and has correct values
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
split_ids = data["split_id"]
|
||||||
|
assert set(split_ids) == {0, 1}
|
||||||
|
|
||||||
|
# Check approximate split sizes (allowing for rounding)
|
||||||
|
split_0_count = split_ids.count(0)
|
||||||
|
split_1_count = split_ids.count(1)
|
||||||
|
assert 25 <= split_0_count <= 35 # ~30% ± tolerance
|
||||||
|
assert 65 <= split_1_count <= 75 # ~70% ± tolerance
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_random_counts(mem_db):
|
||||||
|
"""Test random splitting with absolute counts."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||||
|
)
|
||||||
|
permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()
|
||||||
|
|
||||||
|
# Check that we have exactly the requested counts
|
||||||
|
assert permutation_tbl.count_rows() == 50
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
split_ids = data["split_id"]
|
||||||
|
assert split_ids.count(0) == 20
|
||||||
|
assert split_ids.count(1) == 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_random_fixed(mem_db):
|
||||||
|
"""Test random splitting with fixed number of splits."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||||
|
)
|
||||||
|
permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()
|
||||||
|
|
||||||
|
# Check that we have 4 splits with 25 rows each
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
split_ids = data["split_id"]
|
||||||
|
assert set(split_ids) == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
for split_id in range(4):
|
||||||
|
assert split_ids.count(split_id) == 25
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_random_with_seed(mem_db):
|
||||||
|
"""Test that seeded random splits are reproducible."""
|
||||||
|
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
|
||||||
|
|
||||||
|
# Create two identical permutations with same seed
|
||||||
|
perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||||
|
|
||||||
|
perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()
|
||||||
|
|
||||||
|
# Results should be identical
|
||||||
|
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||||
|
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||||
|
|
||||||
|
assert data1["row_id"] == data2["row_id"]
|
||||||
|
assert data1["split_id"] == data2["split_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_hash(mem_db):
|
||||||
|
"""Test hash-based splitting."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table",
|
||||||
|
pa.table(
|
||||||
|
{
|
||||||
|
"id": range(100),
|
||||||
|
"category": (["A", "B", "C"] * 34)[:100], # Repeating pattern
|
||||||
|
"value": range(100),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have all 100 rows (no discard)
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
split_ids = data["split_id"]
|
||||||
|
assert set(split_ids) == {0, 1}
|
||||||
|
|
||||||
|
# Verify that each split has roughly 50 rows (allowing for hash variance)
|
||||||
|
split_0_count = split_ids.count(0)
|
||||||
|
split_1_count = split_ids.count(1)
|
||||||
|
assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution
|
||||||
|
assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution
|
||||||
|
|
||||||
|
# Hash splits should be deterministic - same category should go to same split
|
||||||
|
# Let's verify by creating another permutation and checking consistency
|
||||||
|
perm2 = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||||
|
assert data["split_id"] == data2["split_id"] # Should be identical
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_hash_with_discard(mem_db):
|
||||||
|
"""Test hash-based splitting with discard weight."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table",
|
||||||
|
pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
|
||||||
|
)
|
||||||
|
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have fewer than 100 rows due to discard
|
||||||
|
row_count = permutation_tbl.count_rows()
|
||||||
|
assert row_count < 100
|
||||||
|
assert row_count > 0 # But not empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_sequential(mem_db):
|
||||||
|
"""Test sequential splitting."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"x": range(100), "y": range(100)})
|
||||||
|
)
|
||||||
|
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert permutation_tbl.count_rows() == 70
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
row_ids = data["row_id"]
|
||||||
|
split_ids = data["split_id"]
|
||||||
|
|
||||||
|
# Sequential should maintain order
|
||||||
|
assert row_ids == sorted(row_ids)
|
||||||
|
|
||||||
|
# First 30 should be split 0, next 40 should be split 1
|
||||||
|
assert split_ids[:30] == [0] * 30
|
||||||
|
assert split_ids[30:] == [1] * 40
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_calculated(mem_db):
|
||||||
|
"""Test calculated splitting."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||||
|
)
|
||||||
|
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_calculated("id % 3") # Split based on id modulo 3
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
row_ids = data["row_id"]
|
||||||
|
split_ids = data["split_id"]
|
||||||
|
|
||||||
|
# Verify the calculation: each row's split_id should equal row_id % 3
|
||||||
|
for i, (row_id, split_id) in enumerate(zip(row_ids, split_ids)):
|
||||||
|
assert split_id == row_id % 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_error_cases(mem_db):
|
||||||
|
"""Test error handling for invalid split parameters."""
|
||||||
|
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
|
||||||
|
|
||||||
|
# Test split_random with no parameters
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
|
permutation_builder(tbl).split_random().execute()
|
||||||
|
|
||||||
|
# Test split_random with multiple parameters
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
|
permutation_builder(tbl).split_random(
|
||||||
|
ratios=[0.5, 0.5], counts=[5, 5]
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
# Test split_sequential with no parameters
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
|
permutation_builder(tbl).split_sequential().execute()
|
||||||
|
|
||||||
|
# Test split_sequential with multiple parameters
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
|
||||||
|
):
|
||||||
|
permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle_no_seed(mem_db):
|
||||||
|
"""Test shuffling without a seed."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a permutation with shuffling (no seed)
|
||||||
|
permutation_tbl = permutation_builder(tbl).shuffle().execute()
|
||||||
|
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
row_ids = data["row_id"]
|
||||||
|
|
||||||
|
# Row IDs should not be in sequential order due to shuffling
|
||||||
|
# This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
|
||||||
|
# in order
|
||||||
|
assert row_ids != list(range(100))
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle_with_seed(mem_db):
|
||||||
|
"""Test that shuffling with a seed is reproducible."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create two identical permutations with same shuffle seed
|
||||||
|
perm1 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||||
|
|
||||||
|
perm2 = permutation_builder(tbl).shuffle(seed=42).execute()
|
||||||
|
|
||||||
|
# Results should be identical due to same seed
|
||||||
|
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||||
|
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||||
|
|
||||||
|
assert data1["row_id"] == data2["row_id"]
|
||||||
|
assert data1["split_id"] == data2["split_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle_with_clump_size(mem_db):
|
||||||
|
"""Test shuffling with clump size."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(100), "value": range(100)})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a permutation with shuffling using clumps
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.shuffle(clump_size=10) # 10-row clumps
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert permutation_tbl.count_rows() == 100
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
row_ids = data["row_id"]
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
start = row_ids[i * 10]
|
||||||
|
assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle_different_seeds(mem_db):
|
||||||
|
"""Test that different seeds produce different shuffle orders."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create two permutations with different shuffle seeds
|
||||||
|
perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()
|
||||||
|
|
||||||
|
perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()
|
||||||
|
|
||||||
|
# Results should be different due to different seeds
|
||||||
|
data1 = perm1.search(None).to_arrow().to_pydict()
|
||||||
|
data2 = perm2.search(None).to_arrow().to_pydict()
|
||||||
|
|
||||||
|
# Row order should be different
|
||||||
|
assert data1["row_id"] != data2["row_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle_combined_with_splits(mem_db):
|
||||||
|
"""Test shuffling combined with different split strategies."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table",
|
||||||
|
pa.table(
|
||||||
|
{
|
||||||
|
"id": range(100),
|
||||||
|
"category": (["A", "B", "C"] * 34)[:100],
|
||||||
|
"value": range(100),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test shuffle with random splits
|
||||||
|
perm_random = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_random(ratios=[0.6, 0.4], seed=42)
|
||||||
|
.shuffle(seed=123, clump_size=None)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test shuffle with hash splits
|
||||||
|
perm_hash = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_hash(["category"], [1, 1], discard_weight=0)
|
||||||
|
.shuffle(seed=456, clump_size=5)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test shuffle with sequential splits
|
||||||
|
perm_sequential = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_sequential(counts=[40, 35])
|
||||||
|
.shuffle(seed=789, clump_size=None)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify all permutations work and have expected properties
|
||||||
|
assert perm_random.count_rows() == 100
|
||||||
|
assert perm_hash.count_rows() == 100
|
||||||
|
assert perm_sequential.count_rows() == 75
|
||||||
|
|
||||||
|
# Verify shuffle affected the order
|
||||||
|
data_random = perm_random.search(None).to_arrow().to_pydict()
|
||||||
|
data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
|
||||||
|
|
||||||
|
assert data_random["row_id"] != list(range(100))
|
||||||
|
assert data_sequential["row_id"] != list(range(75))
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_shuffle_maintains_order(mem_db):
|
||||||
|
"""Test that not calling shuffle maintains the original order."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(50), "value": range(50)})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create permutation without shuffle (should maintain some order)
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.split_sequential(counts=[25, 25]) # Sequential maintains order
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert permutation_tbl.count_rows() == 50
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
row_ids = data["row_id"]
|
||||||
|
|
||||||
|
# With sequential splits and no shuffle, should maintain order
|
||||||
|
assert row_ids == list(range(50))
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_basic(mem_db):
|
||||||
|
"""Test basic filtering functionality."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(100), "value": range(100, 200)})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter to only include rows where id < 50
|
||||||
|
permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()
|
||||||
|
|
||||||
|
assert permutation_tbl.count_rows() == 50
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
row_ids = data["row_id"]
|
||||||
|
|
||||||
|
# All row_ids should be less than 50
|
||||||
|
assert all(row_id < 50 for row_id in row_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_with_splits(mem_db):
|
||||||
|
"""Test filtering combined with split strategies."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table",
|
||||||
|
pa.table(
|
||||||
|
{
|
||||||
|
"id": range(100),
|
||||||
|
"category": (["A", "B", "C"] * 34)[:100],
|
||||||
|
"value": range(100),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter to only category A and B, then split
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.filter("category IN ('A', 'B')")
|
||||||
|
.split_random(ratios=[0.5, 0.5])
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have fewer than 100 rows due to filtering
|
||||||
|
row_count = permutation_tbl.count_rows()
|
||||||
|
assert row_count == 67
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
categories = data["category"]
|
||||||
|
|
||||||
|
# All categories should be A or B
|
||||||
|
assert all(cat in ["A", "B"] for cat in categories)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_with_shuffle(mem_db):
|
||||||
|
"""Test filtering combined with shuffling."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table",
|
||||||
|
pa.table(
|
||||||
|
{
|
||||||
|
"id": range(100),
|
||||||
|
"category": (["A", "B", "C", "D"] * 25)[:100],
|
||||||
|
"value": range(100),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter and shuffle
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.filter("category IN ('A', 'C')")
|
||||||
|
.shuffle(seed=42)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
row_count = permutation_tbl.count_rows()
|
||||||
|
assert row_count == 50 # Should have 50 rows (A and C categories)
|
||||||
|
|
||||||
|
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
||||||
|
row_ids = data["row_id"]
|
||||||
|
|
||||||
|
assert row_ids != sorted(row_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_empty_result(mem_db):
|
||||||
|
"""Test filtering that results in empty set."""
|
||||||
|
tbl = mem_db.create_table(
|
||||||
|
"test_table", pa.table({"id": range(10), "value": range(10)})
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter that matches nothing
|
||||||
|
permutation_tbl = (
|
||||||
|
permutation_builder(tbl)
|
||||||
|
.filter("value > 100") # No values > 100 in our data
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert permutation_tbl.count_rows() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mem_db() -> DBConnection:
|
||||||
|
return connect("memory:///")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def some_table(mem_db: DBConnection) -> Table:
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"id": range(1000),
|
||||||
|
"value": range(1000),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return mem_db.create_table("some_table", data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_split_names(some_table: Table):
|
||||||
|
perm_tbl = (
|
||||||
|
permutation_builder(some_table).split_sequential(counts=[500, 500]).execute()
|
||||||
|
)
|
||||||
|
permutations = Permutations(some_table, perm_tbl)
|
||||||
|
assert permutations.split_names == []
|
||||||
|
assert permutations.split_dict == {}
|
||||||
|
assert permutations[0].num_rows == 500
|
||||||
|
assert permutations[1].num_rows == 500
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def some_perm_table(some_table: Table) -> Table:
|
||||||
|
return (
|
||||||
|
permutation_builder(some_table)
|
||||||
|
.split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
|
||||||
|
.shuffle(seed=42)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nonexistent_split(some_table: Table, some_perm_table: Table):
|
||||||
|
# Reference by name and name does not exist
|
||||||
|
with pytest.raises(ValueError, match="split `nonexistent` is not defined"):
|
||||||
|
Permutation.from_tables(some_table, some_perm_table, "nonexistent")
|
||||||
|
|
||||||
|
# Reference by ordinal and there are no rows
|
||||||
|
with pytest.raises(ValueError, match="No rows found"):
|
||||||
|
Permutation.from_tables(some_table, some_perm_table, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_permutations(some_table: Table, some_perm_table: Table):
|
||||||
|
permutations = Permutations(some_table, some_perm_table)
|
||||||
|
assert permutations.split_names == ["train", "test"]
|
||||||
|
assert permutations.split_dict == {"train": 0, "test": 1}
|
||||||
|
assert permutations["train"].num_rows == 950
|
||||||
|
assert permutations[0].num_rows == 950
|
||||||
    assert permutations["test"].num_rows == 50
    assert permutations[1].num_rows == 50

    with pytest.raises(ValueError, match="No split named `nonexistent` found"):
        permutations["nonexistent"]
    with pytest.raises(ValueError, match="No rows found"):
        permutations[5]


@pytest.fixture
def some_permutation(some_table: Table, some_perm_table: Table) -> Permutation:
    return Permutation.from_tables(some_table, some_perm_table)


def test_num_rows(some_permutation: Permutation):
    assert some_permutation.num_rows == 950


def test_num_columns(some_permutation: Permutation):
    assert some_permutation.num_columns == 2


def test_column_names(some_permutation: Permutation):
    assert some_permutation.column_names == ["id", "value"]


def test_shape(some_permutation: Permutation):
    assert some_permutation.shape == (950, 2)


def test_schema(some_permutation: Permutation):
    assert some_permutation.schema == pa.schema(
        [("id", pa.int64()), ("value", pa.int64())]
    )


def test_limit_offset(some_permutation: Permutation):
    assert some_permutation.with_take(100).num_rows == 100
    assert some_permutation.with_skip(100).num_rows == 850
    assert some_permutation.with_take(100).with_skip(100).num_rows == 100

    with pytest.raises(Exception):
        some_permutation.with_take(1000000).num_rows
    with pytest.raises(Exception):
        some_permutation.with_skip(1000000).num_rows
    with pytest.raises(Exception):
        some_permutation.with_take(500).with_skip(500).num_rows
    with pytest.raises(Exception):
        some_permutation.with_skip(500).with_take(500).num_rows


def test_remove_columns(some_permutation: Permutation):
    assert some_permutation.remove_columns(["value"]).schema == pa.schema(
        [("id", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot remove all columns
    with pytest.raises(ValueError, match="Cannot remove all columns"):
        some_permutation.remove_columns(["id", "value"])


def test_rename_column(some_permutation: Permutation):
    assert some_permutation.rename_column("value", "new_value").schema == pa.schema(
        [("id", pa.int64()), ("new_value", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot rename to an existing column
    with pytest.raises(
        ValueError,
        match="a column with that name already exists",
    ):
        some_permutation.rename_column("value", "id")
    # Cannot rename a non-existent column
    with pytest.raises(
        ValueError,
        match="does not exist",
    ):
        some_permutation.rename_column("non_existent", "new_value")


def test_rename_columns(some_permutation: Permutation):
    assert some_permutation.rename_columns({"value": "new_value"}).schema == pa.schema(
        [("id", pa.int64()), ("new_value", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot rename to an existing column
    with pytest.raises(ValueError, match="a column with that name already exists"):
        some_permutation.rename_columns({"value": "id"})


def test_select_columns(some_permutation: Permutation):
    assert some_permutation.select_columns(["id"]).schema == pa.schema(
        [("id", pa.int64())]
    )
    # Should not modify the original permutation
    assert some_permutation.schema.names == ["id", "value"]
    # Cannot select a non-existent column
    with pytest.raises(ValueError, match="does not exist"):
        some_permutation.select_columns(["non_existent"])
    # Empty selection is not allowed
    with pytest.raises(ValueError, match="select at least one column"):
        some_permutation.select_columns([])


def test_iter_basic(some_permutation: Permutation):
    """Test basic iteration with custom batch size."""
    batch_size = 100
    batches = list(some_permutation.iter(batch_size, skip_last_batch=False))

    # Check that we got the expected number of batches
    expected_batches = (950 + batch_size - 1) // batch_size  # ceiling division
    assert len(batches) == expected_batches

    # Check that all batches are dicts (default python format)
    assert all(isinstance(batch, dict) for batch in batches)

    # Check that batches have the correct structure
    for batch in batches:
        assert "id" in batch
        assert "value" in batch
        assert isinstance(batch["id"], list)
        assert isinstance(batch["value"], list)

    # Check that all batches except the last have the correct size
    for batch in batches[:-1]:
        assert len(batch["id"]) == batch_size
        assert len(batch["value"]) == batch_size

    # Last batch might be smaller
    assert len(batches[-1]["id"]) <= batch_size


def test_iter_skip_last_batch(some_permutation: Permutation):
    """Test iteration with skip_last_batch=True."""
    batch_size = 300
    batches_with_skip = list(some_permutation.iter(batch_size, skip_last_batch=True))
    batches_without_skip = list(
        some_permutation.iter(batch_size, skip_last_batch=False)
    )

    # With skip_last_batch=True, we should have fewer batches if the last one is partial
    num_full_batches = 950 // batch_size
    assert len(batches_with_skip) == num_full_batches

    # Without skip_last_batch, we should have one more batch if there's a remainder
    if 950 % batch_size != 0:
        assert len(batches_without_skip) == num_full_batches + 1
        # Last batch should be smaller
        assert len(batches_without_skip[-1]["id"]) == 950 % batch_size

    # All batches with skip_last_batch should be full size
    for batch in batches_with_skip:
        assert len(batch["id"]) == batch_size


def test_iter_different_batch_sizes(some_permutation: Permutation):
    """Test iteration with different batch sizes."""

    # Test with small batch size
    small_batches = list(some_permutation.iter(100, skip_last_batch=False))
    assert len(small_batches) == 10  # ceiling(950 / 100)

    # Test with large batch size
    large_batches = list(some_permutation.iter(400, skip_last_batch=False))
    assert len(large_batches) == 3  # ceiling(950 / 400)

    # Test with batch size equal to total rows
    single_batch = list(some_permutation.iter(950, skip_last_batch=False))
    assert len(single_batch) == 1
    assert len(single_batch[0]["id"]) == 950

    # Test with batch size larger than total rows
    oversized_batch = list(some_permutation.iter(10000, skip_last_batch=False))
    assert len(oversized_batch) == 1
    assert len(oversized_batch[0]["id"]) == 950


def test_dunder_iter(some_permutation: Permutation):
    """Test the __iter__ method."""
    # __iter__ should use DEFAULT_BATCH_SIZE (100) and skip_last_batch=True
    batches = list(some_permutation)

    # With DEFAULT_BATCH_SIZE=100 and skip_last_batch=True, we should get 9 batches
    assert len(batches) == 9  # floor(950 / 100) since skip_last_batch=True

    # All batches should be full size
    for batch in batches:
        assert len(batch["id"]) == 100
        assert len(batch["value"]) == 100

    some_permutation = some_permutation.with_batch_size(400)
    batches = list(some_permutation)
    assert len(batches) == 2  # floor(950 / 400) since skip_last_batch=True
    for batch in batches:
        assert len(batch["id"]) == 400
        assert len(batch["value"]) == 400


def test_iter_with_different_formats(some_permutation: Permutation):
    """Test iteration with different output formats."""
    batch_size = 100

    # Test with arrow format
    arrow_perm = some_permutation.with_format("arrow")
    arrow_batches = list(arrow_perm.iter(batch_size, skip_last_batch=False))
    assert all(isinstance(batch, pa.RecordBatch) for batch in arrow_batches)

    # Test with python format (default)
    python_perm = some_permutation.with_format("python")
    python_batches = list(python_perm.iter(batch_size, skip_last_batch=False))
    assert all(isinstance(batch, dict) for batch in python_batches)

    # Test with pandas format
    pandas_perm = some_permutation.with_format("pandas")
    pandas_batches = list(pandas_perm.iter(batch_size, skip_last_batch=False))
    # Import pandas to check the type
    import pandas as pd

    assert all(isinstance(batch, pd.DataFrame) for batch in pandas_batches)


def test_iter_with_column_selection(some_permutation: Permutation):
    """Test iteration after column selection."""
    # Select only the id column
    id_only = some_permutation.select_columns(["id"])
    batches = list(id_only.iter(100, skip_last_batch=False))

    # Check that batches only contain the id column
    for batch in batches:
        assert "id" in batch
        assert "value" not in batch


def test_iter_with_column_rename(some_permutation: Permutation):
    """Test iteration after renaming columns."""
    renamed = some_permutation.rename_column("value", "data")
    batches = list(renamed.iter(100, skip_last_batch=False))

    # Check that batches have the renamed column
    for batch in batches:
        assert "id" in batch
        assert "data" in batch
        assert "value" not in batch


def test_iter_with_limit_offset(some_permutation: Permutation):
    """Test iteration with limit and offset."""
    # Test with offset
    offset_perm = some_permutation.with_skip(100)
    offset_batches = list(offset_perm.iter(100, skip_last_batch=False))
    # Should have 850 rows (950 - 100)
    expected_batches = math.ceil(850 / 100)
    assert len(offset_batches) == expected_batches

    # Test with limit
    limit_perm = some_permutation.with_take(500)
    limit_batches = list(limit_perm.iter(100, skip_last_batch=False))
    # Should have 5 batches (500 / 100)
    assert len(limit_batches) == 5

    no_skip = some_permutation.iter(101, skip_last_batch=False)
    row_100 = next(no_skip)["id"][100]

    # Test with both limit and offset
    limited_perm = some_permutation.with_skip(100).with_take(300)
    limited_batches = list(limited_perm.iter(100, skip_last_batch=False))
    # Should have 3 batches (300 / 100)
    assert len(limited_batches) == 3
    assert limited_batches[0]["id"][0] == row_100


def test_iter_empty_permutation(mem_db):
    """Test iteration over an empty permutation."""
    # Create a table and filter it to be empty
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation_tbl = permutation_builder(tbl).filter("value > 100").execute()
    with pytest.raises(ValueError, match="No rows found"):
        Permutation.from_tables(tbl, permutation_tbl)


def test_iter_single_row(mem_db):
    """Test iteration over a permutation with a single row."""
    tbl = mem_db.create_table("test_table", pa.table({"id": [42], "value": [100]}))
    permutation_tbl = permutation_builder(tbl).execute()
    perm = Permutation.from_tables(tbl, permutation_tbl)

    # With skip_last_batch=False, should get one batch
    batches = list(perm.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    assert len(batches[0]["id"]) == 1

    # With skip_last_batch=True, should skip the single row (since it's < batch_size)
    batches_skip = list(perm.iter(10, skip_last_batch=True))
    assert len(batches_skip) == 0


def test_identity_permutation(mem_db):
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    assert permutation.num_rows == 10
    assert permutation.num_columns == 2

    batches = list(permutation.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    assert len(batches[0]["id"]) == 10
    assert len(batches[0]["value"]) == 10

    permutation = permutation.remove_columns(["value"])
    assert permutation.num_columns == 1
    assert permutation.schema == pa.schema([("id", pa.int64())])
    assert permutation.column_names == ["id"]
    assert permutation.shape == (10, 1)


def test_transform_fn(mem_db):
    import numpy as np
    import pandas as pd
    import polars as pl

    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    np_result = list(permutation.with_format("numpy").iter(10, skip_last_batch=False))[
        0
    ]
    assert np_result.shape == (10, 2)
    assert np_result.dtype == np.int64
    assert isinstance(np_result, np.ndarray)

    pd_result = list(permutation.with_format("pandas").iter(10, skip_last_batch=False))[
        0
    ]
    assert pd_result.shape == (10, 2)
    assert pd_result.dtypes.tolist() == [np.int64, np.int64]
    assert isinstance(pd_result, pd.DataFrame)

    pl_result = list(permutation.with_format("polars").iter(10, skip_last_batch=False))[
        0
    ]
    assert pl_result.shape == (10, 2)
    assert pl_result.dtypes == [pl.Int64, pl.Int64]
    assert isinstance(pl_result, pl.DataFrame)

    py_result = list(permutation.with_format("python").iter(10, skip_last_batch=False))[
        0
    ]
    assert len(py_result) == 2
    assert len(py_result["id"]) == 10
    assert len(py_result["value"]) == 10
    assert isinstance(py_result, dict)

    try:
        import torch

        torch_result = list(
            permutation.with_format("torch").iter(10, skip_last_batch=False)
        )[0]
        assert torch_result.shape == (2, 10)
        assert torch_result.dtype == torch.int64
        assert isinstance(torch_result, torch.Tensor)
    except ImportError:
        # Skip check if torch is not installed
        pass

    arrow_result = list(
        permutation.with_format("arrow").iter(10, skip_last_batch=False)
    )[0]
    assert arrow_result.shape == (10, 2)
    assert arrow_result.schema == pa.schema([("id", pa.int64()), ("value", pa.int64())])
    assert isinstance(arrow_result, pa.RecordBatch)


def test_custom_transform(mem_db):
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(10), "value": range(10)})
    )
    permutation = Permutation.identity(tbl)

    def transform(batch: pa.RecordBatch) -> pa.RecordBatch:
        return batch.select(["id"])

    transformed = permutation.with_transform(transform)
    batches = list(transformed.iter(10, skip_last_batch=False))
    assert len(batches) == 1
    batch = batches[0]

    assert batch == pa.record_batch([range(10)], ["id"])
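A rough usage sketch of the Permutation API exercised by these tests; the import path and the connect target are assumptions, and only method calls that appear in the tests above are relied on.

# Hedged sketch: build a filtered permutation and iterate it in Arrow batches.
import pyarrow as pa
import lancedb
from lancedb.dataloader import Permutation, permutation_builder  # assumed import path

db = lancedb.connect("./perm_demo")  # illustrative location
tbl = db.create_table("demo", pa.table({"id": range(100), "value": range(100)}))
# Build the permutation table, then wrap the base and permutation tables in a reader.
perm_tbl = permutation_builder(tbl).filter("value >= 10").execute()
perm = Permutation.from_tables(tbl, perm_tbl)
# Choose columns and an output format, then iterate fixed-size batches.
for batch in perm.select_columns(["id"]).with_format("arrow").iter(16, skip_last_batch=True):
    print(batch.num_rows)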
@@ -1298,6 +1298,79 @@ async def test_query_serialization_async(table_async: AsyncTable):
     )


+def test_query_schema(tmp_path):
+    db = lancedb.connect(tmp_path)
+    tbl = db.create_table(
+        "test",
+        pa.table(
+            {
+                "a": [1, 2, 3],
+                "text": ["a", "b", "c"],
+                "vec": pa.array(
+                    [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
+                ),
+            }
+        ),
+    )
+
+    assert tbl.search(None).output_schema() == pa.schema(
+        {
+            "a": pa.int64(),
+            "text": pa.string(),
+            "vec": pa.list_(pa.float32(), list_size=2),
+        }
+    )
+    assert tbl.search(None).select({"bl": "a * 2"}).output_schema() == pa.schema(
+        {"bl": pa.int64()}
+    )
+    assert tbl.search([1, 2]).select(["a"]).output_schema() == pa.schema(
+        {"a": pa.int64(), "_distance": pa.float32()}
+    )
+    assert tbl.search("blah").select(["a"]).output_schema() == pa.schema(
+        {"a": pa.int64()}
+    )
+    assert tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
+        {"text": pa.string()}
+    )
+
+
+@pytest.mark.asyncio
+async def test_query_schema_async(tmp_path):
+    db = await lancedb.connect_async(tmp_path)
+    tbl = await db.create_table(
+        "test",
+        pa.table(
+            {
+                "a": [1, 2, 3],
+                "text": ["a", "b", "c"],
+                "vec": pa.array(
+                    [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
+                ),
+            }
+        ),
+    )
+
+    assert await tbl.query().output_schema() == pa.schema(
+        {
+            "a": pa.int64(),
+            "text": pa.string(),
+            "vec": pa.list_(pa.float32(), list_size=2),
+        }
+    )
+    assert await tbl.query().select({"bl": "a * 2"}).output_schema() == pa.schema(
+        {"bl": pa.int64()}
+    )
+    assert await tbl.vector_search([1, 2]).select(["a"]).output_schema() == pa.schema(
+        {"a": pa.int64(), "_distance": pa.float32()}
+    )
+    assert await (await tbl.search("blah")).select(["a"]).output_schema() == pa.schema(
+        {"a": pa.int64()}
+    )
+    assert await tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
+        {"text": pa.string()}
+    )
+
+
 def test_query_timeout(tmp_path):
     # Use local directory instead of memory:// to add a bit of latency to
     # operations so a timeout of zero will trigger exceptions.
@@ -484,7 +484,7 @@ def test_jina_reranker(tmp_path, use_tantivy):
 @pytest.mark.parametrize("use_tantivy", [True, False])
 def test_voyageai_reranker(tmp_path, use_tantivy):
     pytest.importorskip("voyageai")
-    reranker = VoyageAIReranker(model_name="rerank-2")
+    reranker = VoyageAIReranker(model_name="rerank-2.5")
     table, schema = get_test_table(tmp_path, use_tantivy)
     _run_test_reranker(reranker, table, "single player experience", None, schema)

@@ -3,19 +3,11 @@
 import pyarrow as pa
 import pytest
+from lancedb.util import tbl_to_tensor

 torch = pytest.importorskip("torch")


-def tbl_to_tensor(tbl):
-    def to_tensor(col: pa.ChunkedArray):
-        if col.num_chunks > 1:
-            raise Exception("Single batch was too large to fit into a one-chunk table")
-        return torch.from_dlpack(col.chunk(0))
-
-    return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
-
-
 def test_table_dataloader(mem_db):
     table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
     dataloader = torch.utils.data.DataLoader(
@@ -4,7 +4,10 @@
 use std::{collections::HashMap, sync::Arc, time::Duration};

 use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
-use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
+use lancedb::{
+    connection::Connection as LanceConnection,
+    database::{CreateTableMode, Database, ReadConsistency},
+};
 use pyo3::{
     exceptions::{PyRuntimeError, PyValueError},
     pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -23,7 +26,7 @@ impl Connection {
         Self { inner: Some(inner) }
     }

-    fn get_inner(&self) -> PyResult<&LanceConnection> {
+    pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
         self.inner
             .as_ref()
             .ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -39,6 +42,10 @@ impl Connection {
             _ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
         }
     }
+
+    pub fn database(&self) -> PyResult<Arc<dyn Database>> {
+        Ok(self.get_inner()?.database().clone())
+    }
 }

 #[pymethods]
@@ -63,6 +70,18 @@ impl Connection {
         self.get_inner().map(|inner| inner.uri().to_string())
     }

+    #[pyo3(signature = ())]
+    pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.get_inner()?.clone();
+        future_into_py(self_.py(), async move {
+            Ok(match inner.read_consistency().await.infer_error()? {
+                ReadConsistency::Manual => None,
+                ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
+                ReadConsistency::Strong => Some(0.0_f64),
+            })
+        })
+    }
+
     #[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
     pub fn table_names(
         self_: PyRef<'_, Self>,
@@ -1,7 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

-use lancedb::index::vector::IvfFlatIndexBuilder;
+use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder};
 use lancedb::index::{
     scalar::{BTreeIndexBuilder, FtsIndexBuilder},
     vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
@@ -87,6 +87,22 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
             }
             Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
         },
+        "IvfRq" => {
+            let params = source.extract::<IvfRqParams>()?;
+            let distance_type = parse_distance_type(params.distance_type)?;
+            let mut ivf_rq_builder = IvfRqIndexBuilder::default()
+                .distance_type(distance_type)
+                .max_iterations(params.max_iterations)
+                .sample_rate(params.sample_rate)
+                .num_bits(params.num_bits);
+            if let Some(num_partitions) = params.num_partitions {
+                ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
+            }
+            if let Some(target_partition_size) = params.target_partition_size {
+                ivf_rq_builder = ivf_rq_builder.target_partition_size(target_partition_size);
+            }
+            Ok(LanceDbIndex::IvfRq(ivf_rq_builder))
+        },
         "HnswPq" => {
             let params = source.extract::<IvfHnswPqParams>()?;
             let distance_type = parse_distance_type(params.distance_type)?;
@@ -170,6 +186,16 @@ struct IvfPqParams {
     target_partition_size: Option<u32>,
 }

+#[derive(FromPyObject)]
+struct IvfRqParams {
+    distance_type: String,
+    num_partitions: Option<u32>,
+    num_bits: u32,
+    max_iterations: u32,
+    sample_rate: u32,
+    target_partition_size: Option<u32>,
+}
+
 #[derive(FromPyObject)]
 struct IvfHnswPqParams {
     distance_type: String,
@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
 use connection::{connect, Connection};
 use env_logger::Env;
 use index::IndexConfig;
+use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
 use pyo3::{
     pymodule,
     types::{PyModule, PyModuleMethods},
@@ -22,6 +23,7 @@ pub mod connection;
 pub mod error;
 pub mod header;
 pub mod index;
+pub mod permutation;
 pub mod query;
 pub mod session;
 pub mod table;
@@ -49,8 +51,12 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<DeleteResult>()?;
     m.add_class::<DropColumnsResult>()?;
     m.add_class::<UpdateResult>()?;
+    m.add_class::<PyAsyncPermutationBuilder>()?;
+    m.add_class::<PyPermutationReader>()?;
     m.add_function(wrap_pyfunction!(connect, m)?)?;
+    m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
     m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
+    m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
     m.add("__version__", env!("CARGO_PKG_VERSION"))?;
     Ok(())
 }
python/src/permutation.rs (new file)
@@ -0,0 +1,331 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::sync::{Arc, Mutex};

use crate::{
    arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
};
use arrow::pyarrow::ToPyArrow;
use lancedb::{
    dataloader::permutation::{
        builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
        reader::PermutationReader,
        split::{SplitSizes, SplitStrategy},
    },
    query::Select,
};
use pyo3::{
    exceptions::PyRuntimeError,
    pyclass, pymethods,
    types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
    Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;

/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
    let table = table.getattr("_inner")?.downcast_into::<Table>()?;
    let inner_table = table.borrow().inner_ref()?.clone();
    let inner_builder = LancePermutationBuilder::new(inner_table);

    Ok(PyAsyncPermutationBuilder {
        state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
            builder: Some(inner_builder),
        })),
    })
}

struct PyAsyncPermutationBuilderState {
    builder: Option<LancePermutationBuilder>,
}

#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
    state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}

impl PyAsyncPermutationBuilder {
    fn modify(
        &self,
        func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
    ) -> PyResult<Self> {
        let mut state = self.state.lock().unwrap();
        let builder = state
            .builder
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
        state.builder = Some(func(builder));
        Ok(Self {
            state: self.state.clone(),
        })
    }
}

#[pymethods]
impl PyAsyncPermutationBuilder {
    #[pyo3(signature = (database, table_name))]
    pub fn persist(
        slf: PyRefMut<'_, Self>,
        database: Bound<'_, PyAny>,
        table_name: String,
    ) -> PyResult<Self> {
        let conn = if database.hasattr("_conn")? {
            database
                .getattr("_conn")?
                .getattr("_inner")?
                .downcast_into::<Connection>()?
        } else {
            database.getattr("_inner")?.downcast_into::<Connection>()?
        };
        let database = conn.borrow().database()?;
        slf.modify(|builder| builder.persist(database, table_name))
    }

    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
    pub fn split_random(
        slf: PyRefMut<'_, Self>,
        ratios: Option<Vec<f64>>,
        counts: Option<Vec<u64>>,
        fixed: Option<u64>,
        seed: Option<u64>,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
            .iter()
            .filter(|&&x| x)
            .count();

        if split_args_count != 1 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = counts {
            SplitSizes::Counts(counts)
        } else if let Some(fixed) = fixed {
            SplitSizes::Fixed(fixed)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        slf.modify(|builder| {
            builder.with_split_strategy(SplitStrategy::Random { seed, sizes }, split_names)
        })
    }

    #[pyo3(signature = (columns, split_weights, *, discard_weight=0, split_names=None))]
    pub fn split_hash(
        slf: PyRefMut<'_, Self>,
        columns: Vec<String>,
        split_weights: Vec<u64>,
        discard_weight: u64,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_split_strategy(
                SplitStrategy::Hash {
                    columns,
                    split_weights,
                    discard_weight,
                },
                split_names,
            )
        })
    }

    #[pyo3(signature = (*, ratios=None, counts=None, fixed=None, split_names=None))]
    pub fn split_sequential(
        slf: PyRefMut<'_, Self>,
        ratios: Option<Vec<f64>>,
        counts: Option<Vec<u64>>,
        fixed: Option<u64>,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        // Check that exactly one split type is provided
        let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
            .iter()
            .filter(|&&x| x)
            .count();

        if split_args_count != 1 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
            ));
        }

        let sizes = if let Some(ratios) = ratios {
            SplitSizes::Percentages(ratios)
        } else if let Some(counts) = counts {
            SplitSizes::Counts(counts)
        } else if let Some(fixed) = fixed {
            SplitSizes::Fixed(fixed)
        } else {
            unreachable!("One of the split arguments must be provided");
        };

        slf.modify(|builder| {
            builder.with_split_strategy(SplitStrategy::Sequential { sizes }, split_names)
        })
    }

    pub fn split_calculated(
        slf: PyRefMut<'_, Self>,
        calculation: String,
        split_names: Option<Vec<String>>,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_split_strategy(SplitStrategy::Calculated { calculation }, split_names)
        })
    }

    pub fn shuffle(
        slf: PyRefMut<'_, Self>,
        seed: Option<u64>,
        clump_size: Option<u64>,
    ) -> PyResult<Self> {
        slf.modify(|builder| {
            builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
        })
    }

    pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
        slf.modify(|builder| builder.with_filter(filter))
    }

    pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let mut state = slf.state.lock().unwrap();
        let builder = state
            .builder
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;

        future_into_py(slf.py(), async move {
            let table = builder.build().await.infer_error()?;
            Ok(Table::new(table))
        })
    }
}

#[pyclass(name = "PermutationReader")]
pub struct PyPermutationReader {
    reader: Arc<PermutationReader>,
}

impl PyPermutationReader {
    fn from_reader(reader: PermutationReader) -> Self {
        Self {
            reader: Arc::new(reader),
        }
    }

    fn parse_selection(selection: Option<Bound<'_, PyAny>>) -> PyResult<Select> {
        let Some(selection) = selection else {
            return Ok(Select::All);
        };
        let selection = selection.downcast_into::<PyDict>()?;
        let selection = selection
            .iter()
            .map(|(key, value)| {
                let key = key.extract::<String>()?;
                let value = value.extract::<String>()?;
                Ok((key, value))
            })
            .collect::<PyResult<Vec<_>>>()?;
        Ok(Select::dynamic(&selection))
    }
}

#[pymethods]
impl PyPermutationReader {
    #[classmethod]
    pub fn from_tables<'py>(
        cls: &Bound<'py, PyType>,
        base_table: Bound<'py, PyAny>,
        permutation_table: Option<Bound<'py, PyAny>>,
        split: u64,
    ) -> PyResult<Bound<'py, PyAny>> {
        let base_table = base_table.getattr("_inner")?.downcast_into::<Table>()?;
        let permutation_table = permutation_table
            .map(|p| PyResult::Ok(p.getattr("_inner")?.downcast_into::<Table>()?))
            .transpose()?;

        let base_table = base_table.borrow().inner_ref()?.base_table().clone();
        let permutation_table = permutation_table
            .map(|p| PyResult::Ok(p.borrow().inner_ref()?.base_table().clone()))
            .transpose()?;

        future_into_py(cls.py(), async move {
            let reader = if let Some(permutation_table) = permutation_table {
                PermutationReader::try_from_tables(base_table, permutation_table, split)
                    .await
                    .infer_error()?
            } else {
                PermutationReader::identity(base_table).await
            };
            Ok(Self::from_reader(reader))
        })
    }

    #[pyo3(signature = (selection=None))]
    pub fn output_schema<'py>(
        slf: PyRef<'py, Self>,
        selection: Option<Bound<'py, PyAny>>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let selection = Self::parse_selection(selection)?;
        let reader = slf.reader.clone();
        future_into_py(slf.py(), async move {
            let schema = reader.output_schema(selection).await.infer_error()?;
            Python::with_gil(|py| schema.to_pyarrow(py))
        })
    }

    #[pyo3(signature = ())]
    pub fn count_rows<'py>(slf: PyRef<'py, Self>) -> u64 {
        slf.reader.count_rows()
    }

    #[pyo3(signature = (offset))]
    pub fn with_offset<'py>(slf: PyRef<'py, Self>, offset: u64) -> PyResult<Bound<'py, PyAny>> {
        let reader = slf.reader.as_ref().clone();
        future_into_py(slf.py(), async move {
            let reader = reader.with_offset(offset).await.infer_error()?;
            Ok(Self::from_reader(reader))
        })
    }

    #[pyo3(signature = (limit))]
    pub fn with_limit<'py>(slf: PyRef<'py, Self>, limit: u64) -> PyResult<Bound<'py, PyAny>> {
        let reader = slf.reader.as_ref().clone();
        future_into_py(slf.py(), async move {
            let reader = reader.with_limit(limit).await.infer_error()?;
            Ok(Self::from_reader(reader))
        })
    }

    #[pyo3(signature = (selection=None, *, batch_size=None))]
    pub fn read<'py>(
        slf: PyRef<'py, Self>,
        selection: Option<Bound<'py, PyAny>>,
        batch_size: Option<u32>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let selection = Self::parse_selection(selection)?;
        let reader = slf.reader.clone();
        let batch_size = batch_size.unwrap_or(1024);
        future_into_py(slf.py(), async move {
            use lancedb::query::QueryExecutionOptions;
            let mut execution_options = QueryExecutionOptions::default();
            execution_options.max_batch_length = batch_size;
            let stream = reader
                .read(selection, execution_options)
                .await
                .infer_error()?;
            Ok(RecordBatchStream::new(stream))
        })
    }
}
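The new permutation.rs module above exposes AsyncPermutationBuilder and PermutationReader to Python. Below is a minimal sketch of driving the builder from async Python; the re-exported entry point name is an assumption, and only the method names and signatures defined in the module above are relied on.

# Hedged sketch of the async builder chain; the import is an assumed re-export.
import asyncio
import pyarrow as pa
import lancedb
from lancedb import async_permutation_builder  # assumed re-export

async def build_permutation(uri: str):
    db = await lancedb.connect_async(uri)
    tbl = await db.create_table("demo", pa.table({"id": range(1000)}))
    builder = async_permutation_builder(tbl)
    # Exactly one of ratios/counts/fixed may be given to split_random.
    builder = builder.split_random(ratios=[0.8, 0.2], seed=42, split_names=["train", "test"])
    builder = builder.shuffle(42, None)  # seeded shuffle, no clump_size
    return await builder.execute()

asyncio.run(build_permutation("./perm_async_demo"))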
@@ -9,6 +9,7 @@ use arrow::array::Array;
 use arrow::array::ArrayData;
 use arrow::pyarrow::FromPyArrow;
 use arrow::pyarrow::IntoPyArrow;
+use arrow::pyarrow::ToPyArrow;
 use lancedb::index::scalar::{
     BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
     Operator, PhraseQuery,
@@ -22,6 +23,7 @@ use lancedb::query::{
 };
 use lancedb::table::AnyQuery;
 use pyo3::prelude::{PyAnyMethods, PyDictMethods};
+use pyo3::pyfunction;
 use pyo3::pymethods;
 use pyo3::types::PyList;
 use pyo3::types::{PyDict, PyString};
@@ -30,6 +32,7 @@ use pyo3::IntoPyObject;
 use pyo3::PyAny;
 use pyo3::PyRef;
 use pyo3::PyResult;
+use pyo3::Python;
 use pyo3::{exceptions::PyRuntimeError, FromPyObject};
 use pyo3::{
     exceptions::{PyNotImplementedError, PyValueError},
@@ -445,6 +448,15 @@ impl Query {
         })
     }

+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
@@ -515,6 +527,15 @@ impl TakeQuery {
         self.inner = self.inner.clone().with_row_id();
     }

+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
@@ -601,6 +622,15 @@ impl FTSQuery {
         self.inner = self.inner.clone().postfilter();
     }

+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
@@ -771,6 +801,15 @@ impl VectorQuery {
         self.inner = self.inner.clone().bypass_vector_index()
     }

+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
@@ -944,3 +983,15 @@ impl HybridQuery {
         req
     }
 }
+
+/// Convert a Python FTS query to JSON string
+#[pyfunction]
+pub fn fts_query_to_json(query_obj: &Bound<'_, PyAny>) -> PyResult<String> {
+    let wrapped: PyLanceDB<FtsQuery> = query_obj.extract()?;
+    lancedb::table::datafusion::udtf::fts::to_json(&wrapped.0).map_err(|e| {
+        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+            "Failed to serialize FTS query to JSON: {}",
+            e
+        ))
+    })
+}
@@ -3,6 +3,7 @@
 use std::{collections::HashMap, sync::Arc};

 use crate::{
+    connection::Connection,
     error::PythonErrorExt,
     index::{extract_index_params, IndexConfig},
     query::{Query, TakeQuery},
@@ -249,7 +250,7 @@ impl Table {
 }

 impl Table {
-    fn inner_ref(&self) -> PyResult<&LanceDbTable> {
+    pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
         self.inner
             .as_ref()
             .ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -272,6 +273,13 @@ impl Table {
         self.inner.take();
     }

+    pub fn database(&self) -> PyResult<Connection> {
+        let inner = self.inner_ref()?.clone();
+        let inner_connection =
+            lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
+        Ok(Connection::new(inner_connection))
+    }
+
     pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {
@@ -1,2 +1,2 @@
 [toolchain]
-channel = "1.86.0"
+channel = "1.90.0"
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.22.2-beta.2"
+version = "0.22.3-beta.5"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
@@ -11,10 +11,12 @@ rust-version.workspace = true

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 [dependencies]
+ahash = { workspace = true }
 arrow = { workspace = true }
 arrow-array = { workspace = true }
 arrow-data = { workspace = true }
 arrow-schema = { workspace = true }
+arrow-select = { workspace = true }
 arrow-ord = { workspace = true }
 arrow-cast = { workspace = true }
 arrow-ipc.workspace = true
@@ -24,19 +26,25 @@ datafusion-common.workspace = true
 datafusion-execution.workspace = true
 datafusion-expr.workspace = true
 datafusion-physical-plan.workspace = true
+datafusion.workspace = true
 object_store = { workspace = true }
 snafu = { workspace = true }
 half = { workspace = true }
 lazy_static.workspace = true
 lance = { workspace = true }
+lance-core = { workspace = true }
 lance-datafusion.workspace = true
+lance-datagen = { workspace = true }
+lance-file = { workspace = true }
 lance-io = { workspace = true }
 lance-index = { workspace = true }
 lance-table = { workspace = true }
 lance-linalg = { workspace = true }
 lance-testing = { workspace = true }
 lance-encoding = { workspace = true }
+lance-arrow = { workspace = true }
 lance-namespace = { workspace = true }
+lance-namespace-impls = { workspace = true }
 moka = { workspace = true }
 pin-project = { workspace = true }
 tokio = { version = "1.23", features = ["rt-multi-thread"] }
@@ -46,11 +54,13 @@ bytes = "1"
 futures.workspace = true
 num-traits.workspace = true
 url.workspace = true
+rand.workspace = true
 regex.workspace = true
 serde = { version = "^1" }
 serde_json = { version = "1" }
 async-openai = { version = "0.20.0", optional = true }
 serde_with = { version = "3.8.1" }
+tempfile = "3.5.0"
 aws-sdk-bedrockruntime = { version = "1.27.0", optional = true }
 # For remote feature
 reqwest = { version = "0.12.0", default-features = false, features = [
@@ -61,9 +71,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [
     "macos-system-configuration",
     "stream",
 ], optional = true }
-rand = { version = "0.9", features = ["small_rng"], optional = true }
 http = { version = "1", optional = true } # Matching what is in reqwest
-uuid = { version = "1.7.0", features = ["v4"], optional = true }
+uuid = { version = "1.7.0", features = ["v4"] }
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
 polars = { version = ">=0.37,<0.40.0", optional = true }
 hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [
@@ -77,14 +86,9 @@ candle-nn = { version = "0.9.1", optional = true }
 tokenizers = { version = "0.19.1", optional = true }
 semver = { workspace = true }

-# For a workaround, see workspace Cargo.toml
-crunchy.workspace = true
-bytemuck_derive.workspace = true
-
 [dev-dependencies]
 anyhow = "1"
 tempfile = "3.5.0"
-rand = { version = "0.9", features = ["small_rng"] }
 random_word = { version = "0.4.3", features = ["en"] }
 uuid = { version = "1.7.0", features = ["v4"] }
 walkdir = "2"
@@ -96,6 +100,7 @@ aws-smithy-runtime = { version = "1.9.1" }
 datafusion.workspace = true
 http-body = "1" # Matching reqwest
 rstest = "0.23.0"
+test-log = "0.2"


 [features]
@@ -105,7 +110,7 @@ oss = ["lance/oss", "lance-io/oss"]
 gcs = ["lance/gcp", "lance-io/gcp"]
 azure = ["lance/azure", "lance-io/azure"]
 dynamodb = ["lance/dynamodb", "aws"]
-remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
+remote = ["dep:reqwest", "dep:http"]
 fp16kernels = ["lance-linalg/fp16kernels"]
 s3-test = []
 bedrock = ["dep:aws-sdk-bedrockruntime"]
@@ -7,6 +7,7 @@ pub use arrow_schema;
 use datafusion_common::DataFusionError;
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use futures::{Stream, StreamExt, TryStreamExt};
+use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};

 #[cfg(feature = "polars")]
 use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
@@ -161,6 +162,26 @@ impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
     }
 }

+pub trait LanceDbDatagenExt {
+    fn into_ldb_stream(
+        self,
+        batch_size: RowCount,
+        num_batches: BatchCount,
+    ) -> SendableRecordBatchStream;
+}
+
+impl LanceDbDatagenExt for BatchGeneratorBuilder {
+    fn into_ldb_stream(
+        self,
+        batch_size: RowCount,
+        num_batches: BatchCount,
+    ) -> SendableRecordBatchStream {
+        let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
+        let stream = stream.map_err(|err| Error::Arrow { source: err });
+        Box::pin(SimpleRecordBatchStream::new(stream, schema))
+    }
+}
+
 #[cfg(feature = "polars")]
 /// An iterator of record batches formed from a Polars DataFrame.
 pub struct PolarsDataFrameRecordBatchReader {

@@ -19,7 +19,7 @@ use crate::database::listing::{
 use crate::database::{
     CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
     CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
-    OpenTableRequest, TableNamesRequest,
+    OpenTableRequest, ReadConsistency, TableNamesRequest,
 };
 use crate::embeddings::{
     EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
@@ -152,6 +152,7 @@ impl CreateTableBuilder<true> {
         let request = self.into_request()?;
         Ok(Table::new_with_embedding_registry(
             parent.create_table(request).await?,
+            parent,
             embedding_registry,
         ))
     }
@@ -211,9 +212,9 @@ impl CreateTableBuilder<false> {
 
     /// Execute the create table operation
     pub async fn execute(self) -> Result<Table> {
-        Ok(Table::new(
-            self.parent.clone().create_table(self.request).await?,
-        ))
+        let parent = self.parent.clone();
+        let table = parent.create_table(self.request).await?;
+        Ok(Table::new(table, parent))
     }
 }
 
@@ -462,8 +463,10 @@ impl OpenTableBuilder {
 
     /// Open the table
     pub async fn execute(self) -> Result<Table> {
+        let table = self.parent.open_table(self.request).await?;
         Ok(Table::new_with_embedding_registry(
-            self.parent.clone().open_table(self.request).await?,
+            table,
+            self.parent,
             self.embedding_registry,
         ))
     }
@@ -519,16 +522,15 @@ impl CloneTableBuilder {
 
     /// Execute the clone operation
     pub async fn execute(self) -> Result<Table> {
-        Ok(Table::new(
-            self.parent.clone().clone_table(self.request).await?,
-        ))
+        let parent = self.parent.clone();
+        let table = parent.clone_table(self.request).await?;
+        Ok(Table::new(table, parent))
     }
 }
 
 /// A connection to LanceDB
 #[derive(Clone)]
 pub struct Connection {
-    uri: String,
     internal: Arc<dyn Database>,
     embedding_registry: Arc<dyn EmbeddingRegistry>,
 }
@@ -540,9 +542,19 @@ impl std::fmt::Display for Connection {
 }
 
 impl Connection {
+    pub fn new(
+        internal: Arc<dyn Database>,
+        embedding_registry: Arc<dyn EmbeddingRegistry>,
+    ) -> Self {
+        Self {
+            internal,
+            embedding_registry,
+        }
+    }
+
     /// Get the URI of the connection
     pub fn uri(&self) -> &str {
-        self.uri.as_str()
+        self.internal.uri()
     }
 
     /// Get access to the underlying database
@@ -675,6 +687,11 @@ impl Connection {
             .await
     }
 
+    /// Get the read consistency of the connection
+    pub async fn read_consistency(&self) -> Result<ReadConsistency> {
+        self.internal.read_consistency().await
+    }
+
     /// Drop a table in the database.
     ///
     /// # Arguments
@@ -973,7 +990,6 @@ impl ConnectBuilder {
         )?);
         Ok(Connection {
             internal,
-            uri: self.request.uri,
             embedding_registry: self
                 .embedding_registry
                 .unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -996,7 +1012,6 @@ impl ConnectBuilder {
         let internal = Arc::new(ListingDatabase::connect_with_options(&self.request).await?);
         Ok(Connection {
             internal,
-            uri: self.request.uri,
             embedding_registry: self
                 .embedding_registry
                 .unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1104,7 +1119,6 @@ impl ConnectNamespaceBuilder {
 
         Ok(Connection {
             internal,
-            uri: format!("namespace://{}", self.ns_impl),
             embedding_registry: self
                 .embedding_registry
                 .unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1139,7 +1153,6 @@ mod test_utils {
         let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
         Self {
             internal,
-            uri: "db://test".to_string(),
             embedding_registry: Arc::new(MemoryRegistry::new()),
         }
     }
@@ -1156,7 +1169,6 @@ mod test_utils {
         ));
         Self {
             internal,
-            uri: "db://test".to_string(),
             embedding_registry: Arc::new(MemoryRegistry::new()),
         }
     }
@@ -1170,13 +1182,13 @@ mod tests {
     use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
     use crate::query::QueryBase;
     use crate::query::{ExecutableQuery, QueryExecutionOptions};
-    use crate::test_connection::test_utils::new_test_connection;
+    use crate::test_utils::connection::new_test_connection;
    use arrow::compute::concat_batches;
    use arrow_array::RecordBatchReader;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{stream, TryStreamExt};
-    use lance::error::{ArrowResult, DataFusionResult};
+    use lance_core::error::{ArrowResult, DataFusionResult};
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
    use tempfile::tempdir;
 
@@ -1187,7 +1199,7 @@ mod tests {
     #[tokio::test]
     async fn test_connect() {
         let tc = new_test_connection().await.unwrap();
-        assert_eq!(tc.connection.uri, tc.uri);
+        assert_eq!(tc.connection.uri(), tc.uri);
     }
 
     #[cfg(not(windows))]
@@ -1208,7 +1220,7 @@ mod tests {
             .await
             .unwrap();
 
-        assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
+        assert_eq!(db.uri(), relative_uri.to_str().unwrap().to_string());
     }
 
     #[tokio::test]
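With the uri field gone, Connection now derives both its URI and its read consistency from the underlying Database implementation. A hedged sketch of the user-facing accessors: the local path is a placeholder and the ReadConsistency import path is assumed from its crate::database location in this diff. Note that for namespace-backed connections the reported URI takes the namespace://<impl> form removed from ConnectNamespaceBuilder above and moved into the database itself below.

    use lancedb::database::ReadConsistency;

    async fn describe_connection() -> lancedb::Result<()> {
        // Placeholder local path; any supported LanceDB URI works here.
        let conn = lancedb::connect("/tmp/example-lancedb").execute().await?;
        println!("connected to {}", conn.uri());
        match conn.read_consistency().await? {
            ReadConsistency::Strong => println!("writes are visible immediately"),
            ReadConsistency::Eventual(d) => println!("changes propagate within {:?}", d),
            ReadConsistency::Manual => println!("call checkout_latest to pick up changes"),
        }
        Ok(())
    }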

@@ -12,7 +12,7 @@ use arrow_array::{
 use arrow_cast::{can_cast_types, cast};
 use arrow_schema::{ArrowError, DataType, Field, Schema};
 use half::f16;
-use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
+use lance_arrow::{DataTypeExt, FixedSizeListArrayExt};
 use log::warn;
 use num_traits::cast::AsPrimitive;
 
@@ -189,7 +189,7 @@ mod tests {
     };
     use arrow_schema::Field;
     use half::f16;
-    use lance::arrow::FixedSizeListArrayExt;
+    use lance_arrow::FixedSizeListArrayExt;
 
     #[test]
     fn test_coerce_list_to_fixed_size_list() {

@@ -16,6 +16,7 @@
 
 use std::collections::HashMap;
 use std::sync::Arc;
+use std::time::Duration;
 
 use arrow_array::RecordBatchReader;
 use async_trait::async_trait;
@@ -213,6 +214,20 @@ impl CloneTableRequest {
     }
 }
 
+/// How long until a change is reflected from one Table instance to another
+///
+/// Tables are always internally consistent. If a write method is called on
+/// a table instance it will be immediately visible in that same table instance.
+pub enum ReadConsistency {
+    /// Changes will not be automatically propagated until the checkout_latest
+    /// method is called on the target table
+    Manual,
+    /// Changes will be propagated automatically within the given duration
+    Eventual(Duration),
+    /// Changes are immediately visible in target tables
+    Strong,
+}
+
 /// The `Database` trait defines the interface for database implementations.
 ///
 /// A database is responsible for managing tables and their metadata.
@@ -220,6 +235,10 @@ impl CloneTableRequest {
 pub trait Database:
     Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
 {
+    /// Get the uri of the database
+    fn uri(&self) -> &str;
+    /// Get the read consistency of the database
+    async fn read_consistency(&self) -> Result<ReadConsistency>;
     /// List immediate child namespace names in the given namespace
     async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>>;
     /// Create a new namespace
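The two Database implementations that follow (the listing database and the namespace database) both derive ReadConsistency from their optional read_consistency_interval. The rule, restated as a standalone helper that is not part of the diff:

    use std::time::Duration;

    // Not in the diff: zero interval => Strong, non-zero => Eventual, unset => Manual.
    fn consistency_from_interval(interval: Option<Duration>) -> ReadConsistency {
        match interval {
            Some(d) if d.is_zero() => ReadConsistency::Strong,
            Some(d) => ReadConsistency::Eventual(d),
            None => ReadConsistency::Manual,
        }
    }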

@@ -17,6 +17,7 @@ use object_store::local::LocalFileSystem;
 use snafu::ResultExt;
 
 use crate::connection::ConnectRequest;
+use crate::database::ReadConsistency;
 use crate::error::{CreateDirSnafu, Error, Result};
 use crate::io::object_store::MirroringObjectStoreWrapper;
 use crate::table::NativeTable;
@@ -454,6 +455,7 @@ impl ListingDatabase {
             // `remove_dir_all` may be used to remove something not be a dataset
             lance::Error::NotFound { .. } => Error::TableNotFound {
                 name: name.to_owned(),
+                source: Box::new(err),
             },
             _ => Error::from(err),
         })?;
@@ -598,6 +600,22 @@ impl Database for ListingDatabase {
         Ok(Vec::new())
     }
 
+    fn uri(&self) -> &str {
+        &self.uri
+    }
+
+    async fn read_consistency(&self) -> Result<ReadConsistency> {
+        if let Some(read_consistency_inverval) = self.read_consistency_interval {
+            if read_consistency_inverval.is_zero() {
+                Ok(ReadConsistency::Strong)
+            } else {
+                Ok(ReadConsistency::Eventual(read_consistency_inverval))
+            }
+        } else {
+            Ok(ReadConsistency::Manual)
+        }
+    }
+
     async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> {
         Err(Error::NotSupported {
             message: "Namespace operations are not supported for listing database".into(),
@@ -1249,7 +1267,8 @@ mod tests {
         )
         .unwrap();
 
-        let source_table_obj = Table::new(source_table.clone());
+        let db = Arc::new(db);
+        let source_table_obj = Table::new(source_table.clone(), db.clone());
         source_table_obj
             .add(Box::new(arrow_array::RecordBatchIterator::new(
                 vec![Ok(batch2)],
@@ -1320,7 +1339,8 @@ mod tests {
         .unwrap();
 
         // Create a tag for the current version
-        let source_table_obj = Table::new(source_table.clone());
+        let db = Arc::new(db);
+        let source_table_obj = Table::new(source_table.clone(), db.clone());
         let mut tags = source_table_obj.tags().await.unwrap();
         tags.create("v1.0", source_table.version().await.unwrap())
             .await
@@ -1336,7 +1356,7 @@ mod tests {
         )
         .unwrap();
 
-        let source_table_obj = Table::new(source_table.clone());
+        let source_table_obj = Table::new(source_table.clone(), db.clone());
         source_table_obj
             .add(Box::new(arrow_array::RecordBatchIterator::new(
                 vec![Ok(batch2)],
@@ -1432,7 +1452,8 @@ mod tests {
         )
         .unwrap();
 
-        let cloned_table_obj = Table::new(cloned_table.clone());
+        let db = Arc::new(db);
+        let cloned_table_obj = Table::new(cloned_table.clone(), db.clone());
         cloned_table_obj
             .add(Box::new(arrow_array::RecordBatchIterator::new(
                 vec![Ok(batch_clone)],
@@ -1452,7 +1473,7 @@ mod tests {
         )
         .unwrap();
 
-        let source_table_obj = Table::new(source_table.clone());
+        let source_table_obj = Table::new(source_table.clone(), db);
         source_table_obj
             .add(Box::new(arrow_array::RecordBatchIterator::new(
                 vec![Ok(batch_source)],
@@ -1495,6 +1516,7 @@ mod tests {
         .unwrap();
 
         // Add more data to create new versions
+        let db = Arc::new(db);
         for i in 0..3 {
             let batch = RecordBatch::try_new(
                 schema.clone(),
@@ -1502,7 +1524,7 @@ mod tests {
             )
             .unwrap();
 
-            let source_table_obj = Table::new(source_table.clone());
+            let source_table_obj = Table::new(source_table.clone(), db.clone());
             source_table_obj
                 .add(Box::new(arrow_array::RecordBatchIterator::new(
                     vec![Ok(batch)],

@@ -8,17 +8,17 @@ use std::sync::Arc;
 
 use async_trait::async_trait;
 use lance_namespace::{
-    connect as connect_namespace,
     models::{
         CreateEmptyTableRequest, CreateNamespaceRequest, DescribeTableRequest,
         DropNamespaceRequest, DropTableRequest, ListNamespacesRequest, ListTablesRequest,
     },
     LanceNamespace,
 };
+use lance_namespace_impls::ConnectBuilder;
 
-use crate::connection::ConnectRequest;
 use crate::database::listing::ListingDatabase;
 use crate::error::{Error, Result};
+use crate::{connection::ConnectRequest, database::ReadConsistency};
 
 use super::{
     BaseTable, CloneTableRequest, CreateNamespaceRequest as DbCreateNamespaceRequest,
@@ -36,6 +36,8 @@ pub struct LanceNamespaceDatabase {
     read_consistency_interval: Option<std::time::Duration>,
     // Optional session for object stores and caching
     session: Option<Arc<lance::session::Session>>,
+    // database URI
+    uri: String,
 }
 
 impl LanceNamespaceDatabase {
@@ -46,17 +48,23 @@ impl LanceNamespaceDatabase {
         read_consistency_interval: Option<std::time::Duration>,
         session: Option<Arc<lance::session::Session>>,
     ) -> Result<Self> {
-        let namespace = connect_namespace(ns_impl, ns_properties.clone())
-            .await
-            .map_err(|e| Error::InvalidInput {
-                message: format!("Failed to connect to namespace: {:?}", e),
-            })?;
+        let mut builder = ConnectBuilder::new(ns_impl);
+        for (key, value) in ns_properties.clone() {
+            builder = builder.property(key, value);
+        }
+        if let Some(ref sess) = session {
+            builder = builder.session(sess.clone());
+        }
+        let namespace = builder.connect().await.map_err(|e| Error::InvalidInput {
+            message: format!("Failed to connect to namespace: {:?}", e),
+        })?;
 
         Ok(Self {
             namespace,
             storage_options,
             read_consistency_interval,
             session,
+            uri: format!("namespace://{}", ns_impl),
         })
     }
 
@@ -130,6 +138,22 @@ impl std::fmt::Display for LanceNamespaceDatabase {
 
 #[async_trait]
 impl Database for LanceNamespaceDatabase {
+    fn uri(&self) -> &str {
+        &self.uri
+    }
+
+    async fn read_consistency(&self) -> Result<ReadConsistency> {
+        if let Some(read_consistency_inverval) = self.read_consistency_interval {
+            if read_consistency_inverval.is_zero() {
+                Ok(ReadConsistency::Strong)
+            } else {
+                Ok(ReadConsistency::Eventual(read_consistency_inverval))
+            }
+        } else {
+            Ok(ReadConsistency::Manual)
+        }
+    }
+
     async fn list_namespaces(&self, request: DbListNamespacesRequest) -> Result<Vec<String>> {
         let ns_request = ListNamespacesRequest {
             id: if request.namespace.is_empty() {

rust/lancedb/src/dataloader.rs (new file, 4 lines)
@@ -0,0 +1,4 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+pub mod permutation;

rust/lancedb/src/dataloader/permutation.rs (new file, 18 lines)
@@ -0,0 +1,18 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+//! Contains the [PermutationBuilder] to create a permutation "view" of an existing table.
+//!
+//! A permutation view can apply a filter, divide the data into splits, and shuffle the data.
+//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
+//! the underlying data and can be very lightweight.
+//!
+//! Building a permutation table should be fairly quick (it is an O(N) operation where N is
+//! the number of rows in the base table) and memory efficient, even for billions or trillions
+//! of rows.
+
+pub mod builder;
+pub mod reader;
+pub mod shuffle;
+pub mod split;
+pub mod util;
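The new permutation module so far only declares its submodules, but its docs describe the core idea: a permutation view is a lightweight table of split ids and row ids over the base table. A purely conceptual sketch, not the lancedb API and not taken from this diff:

    // Conceptual only: a permutation reduces to (split_id, row_id) pairs; a real
    // builder would also apply a filter and optionally shuffle within splits.
    fn assign_round_robin(num_rows: u64, num_splits: u64) -> Vec<(u64, u64)> {
        (0..num_rows).map(|row_id| (row_id % num_splits, row_id)).collect()
    }

    fn main() {
        let permutation = assign_round_robin(10, 3);
        // Row 4 lands in split 1; only ids are stored, never the row data itself.
        assert_eq!(permutation[4], (1, 4));
    }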
Some files were not shown because too many files have changed in this diff.