mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-24 13:59:58 +00:00

Compare commits: 30 commits, python-v0.… → python-v0.…

Commit SHAs:
e1f9b011f8, d664b8739f, 20bec61ecb, 45255be42c, 93c2cf2f59, 9d29c83f81, 2a6143b5bd, b2242886e0, 199904ab35, 1fa888615f, 40967f3baa, 0bfc7de32c, d43880a585, 59a886958b, c36f6746d1, 25ce6d311f, 92a4e46f9f, 845641c480, d96404c635, 02d31ee412, 308623577d, 8ee3ae378f, 3372a2aae0, 4cfcd95320, a70ff04bc9, a9daa18be9, 3f2e3986e9, bf55feb9b6, 8f8e06a2da, 03eab0f091
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.22.2"
+current_version = "0.22.3-beta.3"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.
.github/workflows/codex-update-lance-dependency.yml (vendored, new file, 107 lines)
@@ -0,0 +1,107 @@
name: Codex Update Lance Dependency

on:
  workflow_call:
    inputs:
      tag:
        description: "Tag name from Lance"
        required: true
        type: string
  workflow_dispatch:
    inputs:
      tag:
        description: "Tag name from Lance"
        required: true
        type: string

permissions:
  contents: write
  pull-requests: write
  actions: read

jobs:
  update:
    runs-on: ubuntu-latest
    steps:
      - name: Show inputs
        run: |
          echo "tag = ${{ inputs.tag }}"

      - name: Checkout Repo LanceDB
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: true

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: 20

      - name: Install Codex CLI
        run: npm install -g @openai/codex

      - name: Install Rust toolchain
        uses: dtolnay/rust-toolchain@stable
        with:
          toolchain: stable
          components: clippy, rustfmt

      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y protobuf-compiler libssl-dev

      - name: Install cargo-info
        run: cargo install cargo-info

      - name: Install Python dependencies
        run: python3 -m pip install --upgrade pip packaging

      - name: Configure git user
        run: |
          git config user.name "lancedb automation"
          git config user.email "robot@lancedb.com"

      - name: Configure Codex authentication
        env:
          CODEX_TOKEN_B64: ${{ secrets.CODEX_TOKEN }}
        run: |
          if [ -z "${CODEX_TOKEN_B64}" ]; then
            echo "Repository secret CODEX_TOKEN is not defined; skipping Codex execution."
            exit 1
          fi
          mkdir -p ~/.codex
          echo "${CODEX_TOKEN_B64}" | base64 --decode > ~/.codex/auth.json

      - name: Run Codex to update Lance dependency
        env:
          TAG: ${{ inputs.tag }}
          GITHUB_TOKEN: ${{ secrets.ROBOT_TOKEN }}
          GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
        run: |
          set -euo pipefail
          VERSION="${TAG#refs/tags/}"
          VERSION="${VERSION#v}"
          BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
          cat <<EOF >/tmp/codex-prompt.txt
          You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.

          Follow these steps exactly:
          1. Use script "ci/set_lance_version.py" to update Lance dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
          2. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
          3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
          4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
          5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
          6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
          7. Push the branch to origin. If the branch already exists, force-push your changes.
          8. The "GH_TOKEN" env var is available; use the "gh" CLI for GitHub-related operations such as creating the pull request.
          9. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}).
          10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.

          Constraints:
          - Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.
          - Do not merge the PR.
          - If any command fails, diagnose and fix the issue instead of aborting.
          EOF
          codex --config shell_environment_policy.ignore_default_excludes=true exec --dangerously-bypass-approvals-and-sandbox "$(cat /tmp/codex-prompt.txt)"
AGENTS.md (new file, 101 lines)
@@ -0,0 +1,101 @@
LanceDB is a database designed for retrieval, including vector, full-text, and hybrid search.
It is a wrapper around Lance. There are two backends: local (in-process, like SQLite) and
remote (against LanceDB Cloud).

The core of LanceDB is written in Rust. There are bindings in Python, TypeScript, and Java.

Project layout:

* `rust/lancedb`: The LanceDB core Rust implementation.
* `python`: The Python bindings, using PyO3.
* `nodejs`: The TypeScript bindings, using napi-rs.
* `java`: The Java bindings.

Common commands:

* Check for compiler errors: `cargo check --quiet --features remote --tests --examples`
* Run tests: `cargo test --quiet --features remote --tests`
* Run specific test: `cargo test --quiet --features remote -p <package_name> --test <test_name>`
* Lint: `cargo clippy --quiet --features remote --tests --examples`
* Format: `cargo fmt --all`

Before committing changes, run formatting.

## Coding tips

* When writing Rust doctests for things that require a connection or table reference,
  write them as a function instead of a fully executable test. This allows type checking
  to run but avoids needing a full test environment. For example:

  ```rust
  /// ```
  /// use lance_index::scalar::FullTextSearchQuery;
  /// use lancedb::query::{QueryBase, ExecutableQuery};
  ///
  /// # use lancedb::Table;
  /// # async fn query(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
  /// let results = table.query()
  ///     .full_text_search(FullTextSearchQuery::new("hello world".into()))
  ///     .execute()
  ///     .await?;
  /// # Ok(())
  /// # }
  /// ```
  ```

## Example plan: adding a new method on Table

Adding a new method involves first adding it to the Rust core, then exposing it
in the Python and TypeScript bindings. There are both local and remote tables.
Remote tables are implemented via an HTTP API and require the `remote` cargo
feature flag to be enabled. Python has both sync and async methods.

Rust core changes:

1. Add method on `Table` struct in `rust/lancedb/src/table.rs` (calls `BaseTable` trait).
2. Add method to `BaseTable` trait in `rust/lancedb/src/table.rs`.
3. Implement new trait method on `NativeTable` in `rust/lancedb/src/table.rs`.
   * Test with unit test in `rust/lancedb/src/table.rs`.
4. Implement new trait method on `RemoteTable` in `rust/lancedb/src/remote/table.rs`.
   * Test with unit test in `rust/lancedb/src/remote/table.rs` against mocked endpoint.

Python bindings changes:

1. Add PyO3 method binding in `python/src/table.rs`. Run `make develop` to compile bindings.
2. Add types for PyO3 method in `python/python/lancedb/_lancedb.pyi`.
3. Add method to `AsyncTable` class in `python/python/lancedb/table.py`.
4. Add abstract method to `Table` abstract base class in `python/python/lancedb/table.py`.
5. Add concrete sync method to `LanceTable` class in `python/python/lancedb/table.py`.
   * Should use `LOOP.run()` to call the corresponding `AsyncTable` method.
6. Add concrete sync method to `RemoteTable` class in `python/python/lancedb/remote/table.py`.
7. Add unit test in `python/tests/test_table.py`.

TypeScript bindings changes (a hedged sketch of steps 3–4 follows this list):

1. Add napi-rs method binding on `Table` in `nodejs/src/table.rs`.
2. Run `npm run build` to generate TypeScript definitions.
3. Add TypeScript method on abstract class `Table` in `nodejs/src/table.ts`.
4. Add concrete method on `LocalTable` class in `nodejs/src/native_table.ts`.
   * Note: despite the name, this class is also used for remote tables.
5. Add test in `nodejs/__test__/table.test.ts`.
6. Run `npm run docs` to generate TypeScript documentation.
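For illustration, a minimal TypeScript sketch of steps 3–4 above. The method name `myMethod` and the shape of the native handle are hypothetical placeholders, not the real API:

```ts
// nodejs/src/table.ts -- step 3: declare the method on the abstract class.
export abstract class Table {
  /** Hypothetical new method; each concrete table type must implement it. */
  abstract myMethod(): Promise<number>;
}

// nodejs/src/native_table.ts -- step 4: concrete implementation that
// delegates to the napi-rs binding generated from nodejs/src/table.rs.
interface NativeTableHandle {
  myMethod(): Promise<number>; // assumed shape of the generated binding
}

export class LocalTable extends Table {
  constructor(private readonly inner: NativeTableHandle) {
    super();
  }

  async myMethod(): Promise<number> {
    // Delegate to the Rust core through the generated native binding.
    return await this.inner.myMethod();
  }
}
```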
## Review Guidelines

Please consider the following when reviewing code contributions.

### Rust API design

* Design public APIs so they can be evolved easily in the future without breaking
  changes. Often this means using builder patterns or options structs instead of
  long argument lists.
* For public APIs, prefer inputs that use `Into<T>` or `AsRef<T>` traits to allow
  more flexible inputs. For example, use `name: Into<String>` instead of `name: String`,
  so we don't have to write `func("my_string".to_string())`.

### Testing

* Ensure all new public APIs have documentation and examples.
* Ensure that all bugfixes and features have corresponding tests. **We do not merge
  code without tests.**

### Documentation

* New features must include updates to the Rust documentation comments. Link to
  relevant structs and methods to increase the value of documentation.
CLAUDE.md (deleted, 80 lines)
@@ -1,80 +0,0 @@
(80 deleted lines omitted; the removed CLAUDE.md content is identical to the new AGENTS.md above)
Cargo.lock (generated, 134 lines changed)
@@ -1139,7 +1139,7 @@ dependencies = [
  "bitflags 2.9.4",
  "cexpr",
  "clang-sys",
- "itertools 0.12.1",
+ "itertools 0.11.0",
  "lazy_static",
  "lazycell",
  "log",
@@ -2933,18 +2933,6 @@ version = "0.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55"
 
-[[package]]
-name = "fastbloom"
-version = "0.14.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "18c1ddb9231d8554c2d6bdf4cfaabf0c59251658c68b6c95cd52dd0c513a912a"
-dependencies = [
- "getrandom 0.3.3",
- "libm",
- "rand 0.9.2",
- "siphasher",
-]
-
 [[package]]
 name = "fastdivide"
 version = "0.4.2"
@@ -3044,8 +3032,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
 
 [[package]]
 name = "fsst"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow-array",
  "rand 0.9.2",
@@ -4229,8 +4217,8 @@ dependencies = [
 
 [[package]]
 name = "lance"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4269,6 +4257,7 @@ dependencies = [
  "lance-index",
  "lance-io",
  "lance-linalg",
+ "lance-namespace",
  "lance-table",
  "log",
  "moka",
@@ -4292,8 +4281,8 @@ dependencies = [
 
 [[package]]
 name = "lance-arrow"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -4311,8 +4300,8 @@ dependencies = [
 
 [[package]]
 name = "lance-bitpacking"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrayref",
  "paste",
@@ -4321,8 +4310,8 @@ dependencies = [
 
 [[package]]
 name = "lance-core"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -4358,8 +4347,8 @@ dependencies = [
 
 [[package]]
 name = "lance-datafusion"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4368,6 +4357,7 @@ dependencies = [
  "arrow-schema",
  "arrow-select",
  "async-trait",
+ "chrono",
  "datafusion",
  "datafusion-common",
  "datafusion-functions",
@@ -4387,8 +4377,8 @@ dependencies = [
 
 [[package]]
 name = "lance-datagen"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4405,8 +4395,8 @@ dependencies = [
 
 [[package]]
 name = "lance-encoding"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow-arith",
  "arrow-array",
@@ -4443,8 +4433,8 @@ dependencies = [
 
 [[package]]
 name = "lance-file"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow-arith",
  "arrow-array",
@@ -4469,7 +4459,6 @@ dependencies = [
  "prost",
  "prost-build",
  "prost-types",
- "roaring",
  "snafu",
  "tokio",
  "tracing",
@@ -4477,8 +4466,8 @@ dependencies = [
 
 [[package]]
 name = "lance-index"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4500,7 +4489,6 @@ dependencies = [
  "datafusion-sql",
  "deepsize",
  "dirs",
- "fastbloom",
  "fst",
  "futures",
  "half",
@@ -4540,8 +4528,8 @@ dependencies = [
 
 [[package]]
 name = "lance-io"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4562,6 +4550,7 @@ dependencies = [
  "futures",
  "lance-arrow",
  "lance-core",
+ "lance-namespace",
  "log",
  "object_store",
  "object_store_opendal",
@@ -4580,43 +4569,52 @@ dependencies = [
 
 [[package]]
 name = "lance-linalg"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
  "arrow-ord",
  "arrow-schema",
  "bitvec",
  "cc",
  "deepsize",
  "futures",
  "half",
  "lance-arrow",
  "lance-core",
  "log",
  "num-traits",
  "rand 0.9.2",
  "rayon",
  "tokio",
  "tracing",
 ]
 
 [[package]]
 name = "lance-namespace"
-version = "0.0.18"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7c0629165b5d85ff305f2de8833dcee507e899b36b098864c59f14f3b8b8e62d"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow",
  "async-trait",
  "bytes",
  "lance",
  "lance-core",
  "lance-namespace-reqwest-client",
  "opendal",
  "snafu",
 ]
 
+[[package]]
+name = "lance-namespace-impls"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
+dependencies = [
+ "arrow",
+ "arrow-ipc",
+ "arrow-schema",
+ "async-trait",
+ "bytes",
+ "lance",
+ "lance-core",
+ "lance-io",
+ "lance-namespace",
+ "object_store",
+ "reqwest",
+ "serde_json",
- "thiserror 1.0.69",
+ "snafu",
+ "url",
+]
@@ -4635,8 +4633,8 @@ dependencies = [
 
 [[package]]
 name = "lance-table"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4674,8 +4672,8 @@ dependencies = [
 
 [[package]]
 name = "lance-testing"
-version = "0.38.2"
-source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
+version = "0.38.3"
+source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3#afc0f9832cf11d0bf74381c2b63fd37de1c5f415"
 dependencies = [
  "arrow-array",
  "arrow-schema",
@@ -4686,7 +4684,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb"
-version = "0.22.2"
+version = "0.22.3-beta.3"
 dependencies = [
  "ahash",
  "anyhow",
@@ -4697,6 +4695,7 @@ dependencies = [
  "arrow-ipc",
  "arrow-ord",
  "arrow-schema",
+ "arrow-select",
  "async-openai",
  "async-trait",
  "aws-config",
@@ -4724,6 +4723,7 @@ dependencies = [
  "http 1.3.1",
  "http-body 1.0.1",
  "lance",
+ "lance-arrow",
  "lance-core",
  "lance-datafusion",
  "lance-datagen",
@@ -4733,6 +4733,7 @@ dependencies = [
  "lance-io",
  "lance-linalg",
  "lance-namespace",
+ "lance-namespace-impls",
  "lance-table",
  "lance-testing",
  "lazy_static",
@@ -4780,7 +4781,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb-nodejs"
-version = "0.22.2"
+version = "0.22.3-beta.3"
 dependencies = [
  "arrow-array",
  "arrow-ipc",
@@ -4800,7 +4801,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb-python"
-version = "0.25.2"
+version = "0.25.3-beta.3"
 dependencies = [
  "arrow",
  "async-trait",
@@ -5160,12 +5161,9 @@ dependencies = [
 
 [[package]]
 name = "mock_instant"
-version = "0.3.2"
+version = "0.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9366861eb2a2c436c20b12c8dbec5f798cea6b47ad99216be0282942e2c81ea0"
-dependencies = [
- "once_cell",
-]
+checksum = "dce6dd36094cac388f119d2e9dc82dc730ef91c32a6222170d630e5414b956e6"
 
 [[package]]
 name = "moka"
@@ -6395,8 +6393,8 @@ version = "0.13.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
 dependencies = [
- "heck 0.5.0",
- "itertools 0.14.0",
+ "heck 0.4.1",
+ "itertools 0.11.0",
  "log",
  "multimap",
  "once_cell",
@@ -6416,7 +6414,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
 dependencies = [
  "anyhow",
- "itertools 0.14.0",
+ "itertools 0.11.0",
  "proc-macro2",
  "quote",
  "syn 2.0.106",
@@ -7724,7 +7722,7 @@ version = "0.8.9"
 checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
 dependencies = [
- "heck 0.5.0",
+ "heck 0.4.1",
  "proc-macro2",
  "quote",
  "syn 2.0.106",
Cargo.toml (39 lines changed)
@@ -15,18 +15,20 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"
 
 [workspace.dependencies]
-lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"], "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-core = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-datagen = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-file = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-io = { "version" = "=0.38.2", default-features = false, "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-namespace = "0.0.18"
+lance = { "version" = "=0.38.3", default-features = false, "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-core = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-datagen = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-file = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-io = { "version" = "=0.38.3", default-features = false, "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-index = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-linalg = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-namespace = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-namespace-impls = { "version" = "=0.38.3", "features" = ["dir-aws", "dir-gcp", "dir-azure", "dir-oss", "rest"], "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-table = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-testing = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-datafusion = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-encoding = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
+lance-arrow = { "version" = "=0.38.3", "tag" = "v0.38.3", "git" = "https://github.com/lancedb/lance.git" }
 ahash = "0.8"
 # Note that this one does not include pyarrow
 arrow = { version = "56.2", optional = false }
@@ -35,6 +37,7 @@ arrow-data = "56.2"
 arrow-ipc = "56.2"
 arrow-ord = "56.2"
 arrow-schema = "56.2"
+arrow-select = "56.2"
 arrow-cast = "56.2"
 async-trait = "0"
 datafusion = { version = "50.1", default-features = false }
@@ -63,15 +66,3 @@ crunchy = "0.2.4"
 chrono = "0.4"
 # Workaround for: https://github.com/Lokathor/bytemuck/issues/306
 bytemuck_derive = ">=1.8.1, <1.9.0"
-
-# This is only needed when we reference preview releases of lance
-# Force to use the same lance version as the rest of the project to avoid duplicate dependencies
-[patch.crates-io]
-lance = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-io = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
-lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
ci/set_lance_version.py
@@ -117,7 +117,7 @@ def update_cargo_toml(line_updater):
     lance_line = ""
     is_parsing_lance_line = False
     for line in lines:
-        if line.startswith("lance") and not line.startswith("lance-namespace"):
+        if line.startswith("lance"):
             # Check if this is a single-line or multi-line entry
             # Single-line entries either:
             # 1. End with } (complete inline table)
@@ -183,10 +183,8 @@ def set_preview_version(version: str):
 
     def line_updater(line: str) -> str:
         package_name = line.split("=", maxsplit=1)[0].strip()
-        base_version = version.split("-")[0]  # Get the base version without beta suffix
-
         # Build config in desired order: version, default-features, features, tag, git
-        config = {"version": f"={base_version}"}
+        config = {"version": f"={version}"}
 
         if extract_default_features(line):
             config["default-features"] = False
@@ -0,0 +1,97 @@
# VoyageAI Embeddings: Multimodal

VoyageAI embeddings can also be used to embed both text and image data. Only some of the models support image data; the list is available at
[https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings).

Supported parameters (to be passed in the `create` method) are:

| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |

Usage Example:

```python
import base64
import os
from io import BytesIO

import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd

os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'

db = lancedb.connect(".lancedb")
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")


def image_to_base64(image_bytes: bytes):
    buffered = BytesIO(image_bytes)
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode("utf-8")


class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()  # image uri as the source
    image_bytes: str = func.SourceField()  # image bytes base64 encoded as the source
    vector: Vector(func.ndims()) = func.VectorField()  # vector column
    vec_from_bytes: Vector(func.ndims()) = func.VectorField()  # another vector column


if "images" in db.table_names():
    db.drop_table("images")
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
    "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
    "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
table.add(
    pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
```

Now we can search using text, from both the default vector column and the custom vector column:

```python
# text search
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
print(actual.label)  # prints "dog"

frombytes = (
    table.search("man's best friend", vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(frombytes.label)
```

Because we're using a multimodal embedding function, we can also search using images:

```python
from PIL import Image

# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")

# image search using a custom vector column
other = (
    table.search(query_image, vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(other.label)
```
@@ -397,117 +397,6 @@ For **read-only access**, LanceDB will need a policy such as:
 }
 ```
 
-#### DynamoDB Commit Store for concurrent writes
-
-By default, S3 does not support concurrent writes. Having two or more processes
-writing to the same table at the same time can lead to data corruption. This is
-because S3, unlike other object stores, does not have any atomic put or copy
-operation.
-
-To enable concurrent writes, you can configure LanceDB to use a DynamoDB table
-as a commit store. This table will be used to coordinate writes between
-different processes. To enable this feature, you must modify your connection
-URI to use the `s3+ddb` scheme and add a query parameter `ddbTableName` with the
-name of the table to use.
-
-=== "Python"
-
-    === "Sync API"
-
-        ```python
-        import lancedb
-        db = lancedb.connect(
-            "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
-        )
-        ```
-    === "Async API"
-
-        ```python
-        import lancedb
-        async_db = await lancedb.connect_async(
-            "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
-        )
-        ```
-
-=== "JavaScript"
-
-    ```javascript
-    const lancedb = require("lancedb");
-
-    const db = await lancedb.connect(
-      "s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
-    );
-    ```
-
-The DynamoDB table must be created with the following schema:
-
-- Hash key: `base_uri` (string)
-- Range key: `version` (number)
-
-You can create this programmatically with:
-
-=== "Python"
-
-    <!-- skip-test -->
-    ```python
-    import boto3
-
-    dynamodb = boto3.client("dynamodb")
-    table = dynamodb.create_table(
-        TableName=table_name,
-        KeySchema=[
-            {"AttributeName": "base_uri", "KeyType": "HASH"},
-            {"AttributeName": "version", "KeyType": "RANGE"},
-        ],
-        AttributeDefinitions=[
-            {"AttributeName": "base_uri", "AttributeType": "S"},
-            {"AttributeName": "version", "AttributeType": "N"},
-        ],
-        ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
-    )
-    ```
-
-=== "JavaScript"
-
-    <!-- skip-test -->
-    ```javascript
-    import {
-      CreateTableCommand,
-      DynamoDBClient,
-    } from "@aws-sdk/client-dynamodb";
-
-    const dynamodb = new DynamoDBClient({
-      region: CONFIG.awsRegion,
-      credentials: {
-        accessKeyId: CONFIG.awsAccessKeyId,
-        secretAccessKey: CONFIG.awsSecretAccessKey,
-      },
-      endpoint: CONFIG.awsEndpoint,
-    });
-    const command = new CreateTableCommand({
-      TableName: table_name,
-      AttributeDefinitions: [
-        {
-          AttributeName: "base_uri",
-          AttributeType: "S",
-        },
-        {
-          AttributeName: "version",
-          AttributeType: "N",
-        },
-      ],
-      KeySchema: [
-        { AttributeName: "base_uri", KeyType: "HASH" },
-        { AttributeName: "version", KeyType: "RANGE" },
-      ],
-      ProvisionedThroughput: {
-        ReadCapacityUnits: 1,
-        WriteCapacityUnits: 1,
-      },
-    });
-    await client.send(command);
-    ```
-
-
 #### S3-compatible stores
@@ -80,7 +80,7 @@ AnalyzeExec verbose=true, metrics=[]
 ### execute()
 
 ```ts
-protected execute(options?): RecordBatchIterator
+protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
 ```
 
 Execute the query and return the results as an
@@ -91,7 +91,7 @@ Execute the query and return the results as an
 
 #### Returns
 
-[`RecordBatchIterator`](RecordBatchIterator.md)
+`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
 
 #### See
 
@@ -343,6 +343,29 @@ This is useful for pagination.
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the output that will be returned by this query.
+
+This can be used to inspect the types and names of the columns that will be
+returned by the query before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+#### Inherited from
+
+`StandardQueryBase.outputSchema`
+
+***
+
 ### select()
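For context on the `outputSchema()` API added in these docs, a minimal usage sketch (the database path and table/column names are illustrative, not from this diff):

```ts
import { connect } from "@lancedb/lancedb";

async function inspectSchema() {
  const db = await connect("/tmp/lancedb-demo");  // illustrative path
  const table = await db.openTable("my_table");   // illustrative table name
  const query = table.query().select(["id", "vector"]).limit(10);
  // Inspect the output columns before running the query.
  const schema = await query.outputSchema();
  console.log(schema.fields.map((f) => `${f.name}: ${f.type}`));
}
```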
@@ -81,7 +81,7 @@ AnalyzeExec verbose=true, metrics=[]
 ### execute()
 
 ```ts
-protected execute(options?): RecordBatchIterator
+protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
 ```
 
 Execute the query and return the results as an
@@ -92,7 +92,7 @@ Execute the query and return the results as an
 
 #### Returns
 
-[`RecordBatchIterator`](RecordBatchIterator.md)
+`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
 
 #### See
 
@@ -140,6 +140,25 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the output that will be returned by this query.
+
+This can be used to inspect the types and names of the columns that will be
+returned by the query before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+***
+
 ### select()
docs/src/js/classes/RecordBatchIterator.md (deleted)
@@ -1,43 +0,0 @@
-[**@lancedb/lancedb**](../README.md) • **Docs**
-
-***
-
-[@lancedb/lancedb](../globals.md) / RecordBatchIterator
-
-# Class: RecordBatchIterator
-
-## Implements
-
-- `AsyncIterator`<`RecordBatch`>
-
-## Constructors
-
-### new RecordBatchIterator()
-
-```ts
-new RecordBatchIterator(promise?): RecordBatchIterator
-```
-
-#### Parameters
-
-* **promise?**: `Promise`<`RecordBatchIterator`>
-
-#### Returns
-
-[`RecordBatchIterator`](RecordBatchIterator.md)
-
-## Methods
-
-### next()
-
-```ts
-next(): Promise<IteratorResult<RecordBatch<any>, any>>
-```
-
-#### Returns
-
-`Promise`<`IteratorResult`<`RecordBatch`<`any`>, `any`>>
-
-#### Implementation of
-
-`AsyncIterator.next`
@@ -76,7 +76,7 @@ AnalyzeExec verbose=true, metrics=[]
 ### execute()
 
 ```ts
-protected execute(options?): RecordBatchIterator
+protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
 ```
 
 Execute the query and return the results as an
@@ -87,7 +87,7 @@ Execute the query and return the results as an
 
 #### Returns
 
-[`RecordBatchIterator`](RecordBatchIterator.md)
+`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
 
 #### See
 
@@ -143,6 +143,29 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the output that will be returned by this query.
+
+This can be used to inspect the types and names of the columns that will be
+returned by the query before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+#### Inherited from
+
+[`QueryBase`](QueryBase.md).[`outputSchema`](QueryBase.md#outputschema)
+
+***
+
 ### select()
@@ -221,7 +221,7 @@ also increase the latency of your query. The default value is 1.5*limit.
 ### execute()
 
 ```ts
-protected execute(options?): RecordBatchIterator
+protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
 ```
 
 Execute the query and return the results as an
@@ -232,7 +232,7 @@ Execute the query and return the results as an
 
 #### Returns
 
-[`RecordBatchIterator`](RecordBatchIterator.md)
+`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
 
 #### See
 
@@ -498,6 +498,29 @@ This is useful for pagination.
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the output that will be returned by this query.
+
+This can be used to inspect the types and names of the columns that will be
+returned by the query before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+#### Inherited from
+
+`StandardQueryBase.outputSchema`
+
+***
+
 ### postfilter()
docs/src/js/functions/RecordBatchIterator.md (new file, 19 lines)
@@ -0,0 +1,19 @@
[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / RecordBatchIterator

# Function: RecordBatchIterator()

```ts
function RecordBatchIterator(promisedInner): AsyncGenerator<RecordBatch<any>, void, unknown>
```

## Parameters

* **promisedInner**: `Promise`<`RecordBatchIterator`>

## Returns

`AsyncGenerator`<`RecordBatch`<`any`>, `void`, `unknown`>
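`RecordBatchIterator` is now a helper function producing an `AsyncGenerator` rather than a class, so in practice results are consumed with `for await` on the query itself. A minimal sketch (the database path and table name are illustrative):

```ts
import { connect } from "@lancedb/lancedb";

async function streamResults() {
  const db = await connect("/tmp/lancedb-demo");  // illustrative path
  const table = await db.openTable("my_table");   // illustrative table name
  // Each yielded item is an Arrow RecordBatch; iterating the query drives
  // the same generator that RecordBatchIterator() wraps.
  for await (const batch of table.query().limit(100)) {
    console.log(batch.numRows);
  }
}
```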
@@ -7,7 +7,7 @@
 # Function: permutationBuilder()
 
 ```ts
-function permutationBuilder(table, destTableName): PermutationBuilder
+function permutationBuilder(table): PermutationBuilder
 ```
 
 Create a permutation builder for the given table.
@@ -17,9 +17,6 @@ Create a permutation builder for the given table.
 * **table**: [`Table`](../classes/Table.md)
   The source table to create a permutation from
 
-* **destTableName**: `string`
-  The name for the destination permutation table
-
 ## Returns
 
 [`PermutationBuilder`](../classes/PermutationBuilder.md)
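A hedged usage sketch of the new single-argument form, mirroring the chained calls in the updated tests later in this diff (the connection path and table name are illustrative, and the import path is assumed to be the package root):

```ts
import { connect, permutationBuilder } from "@lancedb/lancedb";

async function buildPermutation() {
  const db = await connect("/tmp/lancedb-demo");  // illustrative path
  const table = await db.openTable("my_table");   // illustrative table name
  // The destination table name is no longer passed to permutationBuilder().
  const permutation = await permutationBuilder(table)
    .filter("value <= 80")
    .splitRandom({ ratios: [0.5, 0.5], seed: 42 })
    .shuffle({ seed: 123 })
    .execute();
  console.log(await permutation.countRows());
}
```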
@@ -32,7 +32,6 @@
 - [PhraseQuery](classes/PhraseQuery.md)
 - [Query](classes/Query.md)
 - [QueryBase](classes/QueryBase.md)
-- [RecordBatchIterator](classes/RecordBatchIterator.md)
 - [Session](classes/Session.md)
 - [StaticHeaderProvider](classes/StaticHeaderProvider.md)
 - [Table](classes/Table.md)
@@ -105,6 +104,7 @@
 
 ## Functions
 
+- [RecordBatchIterator](functions/RecordBatchIterator.md)
 - [connect](functions/connect.md)
 - [makeArrowTable](functions/makeArrowTable.md)
 - [packBits](functions/packBits.md)
docs/src/js/interfaces/IvfRqOptions.md (new file, 101 lines)
@@ -0,0 +1,101 @@
[**@lancedb/lancedb**](../README.md) • **Docs**

***

[@lancedb/lancedb](../globals.md) / IvfRqOptions

# Interface: IvfRqOptions

## Properties

### distanceType?

```ts
optional distanceType: "l2" | "cosine" | "dot";
```

Distance type to use to build the index.

Default value is "l2".

This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and during quantization.

The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.

The following distance types are available:

"l2" - Euclidean distance.
"cosine" - Cosine distance.
"dot" - Dot product.

***

### maxIterations?

```ts
optional maxIterations: number;
```

Max iterations to train IVF kmeans.

When training an IVF index we use kmeans to calculate the partitions. This parameter
controls how many iterations of kmeans to run.

The default value is 50.

***

### numBits?

```ts
optional numBits: number;
```

Number of bits per dimension for residual quantization.

This value controls how much each residual component is compressed. The more
bits, the more accurate the index will be, but the slower the search. Typical values
are small integers; the default is 1 bit per dimension.

***

### numPartitions?

```ts
optional numPartitions: number;
```

The number of IVF partitions to create.

This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.

If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.

***

### sampleRate?

```ts
optional sampleRate: number;
```

The number of vectors, per partition, to sample when training IVF kmeans.

When an IVF index is trained, we need to calculate partitions. These are groups
of vectors that are similar to each other. To do this we use an algorithm called kmeans.

Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
random sample of the data. This parameter controls the size of the sample. The total
number of vectors used to train the index is `sample_rate * num_partitions`.

Increasing this value might improve the quality of the index but in most cases the
default should be sufficient.

The default value is 256.
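A hedged sketch of how these options might be passed when creating an index. The `Index.ivfRq` factory name is assumed by analogy with the package's other index types (e.g. `Index.ivfPq`); everything else follows the documented option names:

```ts
import { Index, connect } from "@lancedb/lancedb";

async function createIvfRqIndex() {
  const db = await connect("/tmp/lancedb-demo");  // illustrative path
  const table = await db.openTable("my_table");   // illustrative table name
  await table.createIndex("vector", {
    // Assumed factory name, mirroring Index.ivfPq and friends.
    config: Index.ivfRq({
      distanceType: "cosine", // must match the distance type used at search time
      numPartitions: 1024,    // default is roughly sqrt(row count)
      numBits: 1,             // bits per dimension for the residual quantizer
      sampleRate: 256,        // vectors sampled per partition for kmeans
      maxIterations: 50,      // kmeans iterations
    }),
  });
}
```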
@@ -8,7 +8,7 @@
 <parent>
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.22.2-final.0</version>
+  <version>0.22.3-beta.3</version>
   <relativePath>../pom.xml</relativePath>
 </parent>
 
@@ -8,7 +8,7 @@
 <parent>
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.22.2-final.0</version>
+  <version>0.22.3-beta.3</version>
   <relativePath>../pom.xml</relativePath>
 </parent>
 
@@ -6,7 +6,7 @@
 
 <groupId>com.lancedb</groupId>
 <artifactId>lancedb-parent</artifactId>
-<version>0.22.2-final.0</version>
+<version>0.22.3-beta.3</version>
 <packaging>pom</packaging>
 <name>${project.artifactId}</name>
 <description>LanceDB Java SDK Parent POM</description>
nodejs/AGENTS.md (new file, 13 lines)
@@ -0,0 +1,13 @@
These are the TypeScript bindings of LanceDB.
The core Rust library is in the `../rust/lancedb` directory, the Rust binding
code is in the `src/` directory, and the TypeScript bindings are in
the `lancedb/` directory.

Whenever you change the Rust code, you will need to recompile: `npm run build`.

Common commands:
* Build: `npm run build`
* Lint: `npm run lint`
* Fix lints: `npm run lint-fix`
* Test: `npm test`
* Run single test file: `npm test __test__/arrow.test.ts`
@@ -1,13 +0,0 @@
(13 deleted lines omitted; identical to the new nodejs/AGENTS.md above)
nodejs/CLAUDE.md (new symbolic link, 1 line)
@@ -0,0 +1 @@
AGENTS.md
nodejs/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.22.2"
+version = "0.22.3-beta.3"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
@@ -38,23 +38,22 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation builder", () => {
-    const builder = permutationBuilder(table, "permutation_table");
+    const builder = permutationBuilder(table);
     expect(builder).toBeDefined();
   });
 
   test("should execute basic permutation", async () => {
-    const builder = permutationBuilder(table, "permutation_table");
+    const builder = permutationBuilder(table);
     const permutationTable = await builder.execute();
 
     expect(permutationTable).toBeDefined();
-    expect(permutationTable.name).toBe("permutation_table");
 
     const rowCount = await permutationTable.countRows();
     expect(rowCount).toBe(10);
   });
 
   test("should create permutation with random splits", async () => {
-    const builder = permutationBuilder(table, "permutation_table").splitRandom({
+    const builder = permutationBuilder(table).splitRandom({
       ratios: [1.0],
       seed: 42,
     });
@@ -65,7 +64,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with percentage splits", async () => {
-    const builder = permutationBuilder(table, "permutation_table").splitRandom({
+    const builder = permutationBuilder(table).splitRandom({
       ratios: [0.3, 0.7],
       seed: 42,
     });
@@ -84,7 +83,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with count splits", async () => {
-    const builder = permutationBuilder(table, "permutation_table").splitRandom({
+    const builder = permutationBuilder(table).splitRandom({
       counts: [3, 7],
       seed: 42,
     });
@@ -102,7 +101,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with hash splits", async () => {
-    const builder = permutationBuilder(table, "permutation_table").splitHash({
+    const builder = permutationBuilder(table).splitHash({
      columns: ["id"],
      splitWeights: [50, 50],
      discardWeight: 0,
@@ -122,10 +121,9 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with sequential splits", async () => {
-    const builder = permutationBuilder(
-      table,
-      "permutation_table",
-    ).splitSequential({ ratios: [0.5, 0.5] });
+    const builder = permutationBuilder(table).splitSequential({
+      ratios: [0.5, 0.5],
+    });
 
     const permutationTable = await builder.execute();
     const rowCount = await permutationTable.countRows();
@@ -140,10 +138,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with calculated splits", async () => {
-    const builder = permutationBuilder(
-      table,
-      "permutation_table",
-    ).splitCalculated("id % 2");
+    const builder = permutationBuilder(table).splitCalculated("id % 2");
 
     const permutationTable = await builder.execute();
     const rowCount = await permutationTable.countRows();
@@ -159,7 +154,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with shuffle", async () => {
-    const builder = permutationBuilder(table, "permutation_table").shuffle({
+    const builder = permutationBuilder(table).shuffle({
       seed: 42,
     });
 
@@ -169,7 +164,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with shuffle and clump size", async () => {
-    const builder = permutationBuilder(table, "permutation_table").shuffle({
+    const builder = permutationBuilder(table).shuffle({
       seed: 42,
       clumpSize: 2,
     });
@@ -180,9 +175,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should create permutation with filter", async () => {
-    const builder = permutationBuilder(table, "permutation_table").filter(
-      "value > 50",
-    );
+    const builder = permutationBuilder(table).filter("value > 50");
 
     const permutationTable = await builder.execute();
     const rowCount = await permutationTable.countRows();
@@ -190,7 +183,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should chain multiple operations", async () => {
-    const builder = permutationBuilder(table, "permutation_table")
+    const builder = permutationBuilder(table)
       .filter("value <= 80")
       .splitRandom({ ratios: [0.5, 0.5], seed: 42 })
       .shuffle({ seed: 123 });
@@ -209,7 +202,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should throw error for invalid split arguments", () => {
-    const builder = permutationBuilder(table, "permutation_table");
+    const builder = permutationBuilder(table);
 
     // Test no arguments provided
     expect(() => builder.splitRandom({})).toThrow(
@@ -223,7 +216,7 @@ describe("PermutationBuilder", () => {
   });
 
   test("should throw error when builder is consumed", async () => {
-    const builder = permutationBuilder(table, "permutation_table");
+    const builder = permutationBuilder(table);
 
     // Execute once
     await builder.execute();
111
nodejs/__test__/query.test.ts
Normal file
@@ -0,0 +1,111 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

import * as tmp from "tmp";

import { type Table, connect } from "../lancedb";
import {
  Field,
  FixedSizeList,
  Float32,
  Int64,
  Schema,
  Utf8,
  makeArrowTable,
} from "../lancedb/arrow";
import { Index } from "../lancedb/indices";

describe("Query outputSchema", () => {
  let tmpDir: tmp.DirResult;
  let table: Table;

  beforeEach(async () => {
    tmpDir = tmp.dirSync({ unsafeCleanup: true });
    const db = await connect(tmpDir.name);

    // Create table with explicit schema to ensure proper types
    const schema = new Schema([
      new Field("a", new Int64(), true),
      new Field("text", new Utf8(), true),
      new Field(
        "vec",
        new FixedSizeList(2, new Field("item", new Float32())),
        true,
      ),
    ]);

    const data = makeArrowTable(
      [
        { a: 1n, text: "foo", vec: [1, 2] },
        { a: 2n, text: "bar", vec: [3, 4] },
        { a: 3n, text: "baz", vec: [5, 6] },
      ],
      { schema },
    );
    table = await db.createTable("test", data);
  });

  afterEach(() => {
    tmpDir.removeCallback();
  });

  it("should return schema for plain query", async () => {
    const schema = await table.query().outputSchema();

    expect(schema.fields.length).toBe(3);
    expect(schema.fields.map((f) => f.name)).toEqual(["a", "text", "vec"]);
    expect(schema.fields[0].type.toString()).toBe("Int64");
    expect(schema.fields[1].type.toString()).toBe("Utf8");
  });

  it("should return schema with dynamic projection", async () => {
    const schema = await table.query().select({ bl: "a * 2" }).outputSchema();

    expect(schema.fields.length).toBe(1);
    expect(schema.fields[0].name).toBe("bl");
    expect(schema.fields[0].type.toString()).toBe("Int64");
  });

  it("should return schema for vector search with _distance column", async () => {
    const schema = await table
      .vectorSearch([1, 2])
      .select(["a"])
      .outputSchema();

    expect(schema.fields.length).toBe(2);
    expect(schema.fields.map((f) => f.name)).toEqual(["a", "_distance"]);
    expect(schema.fields[0].type.toString()).toBe("Int64");
    expect(schema.fields[1].type.toString()).toBe("Float32");
  });

  it("should return schema for FTS search", async () => {
    await table.createIndex("text", { config: Index.fts() });

    const schema = await table
      .search("foo", "fts")
      .select(["a"])
      .outputSchema();

    // FTS search includes _score column in addition to selected columns
    expect(schema.fields.length).toBe(2);
    expect(schema.fields.map((f) => f.name)).toContain("a");
    expect(schema.fields.map((f) => f.name)).toContain("_score");
    const aField = schema.fields.find((f) => f.name === "a");
    expect(aField?.type.toString()).toBe("Int64");
  });

  it("should return schema for take query", async () => {
    const schema = await table.takeOffsets([0]).select(["text"]).outputSchema();

    expect(schema.fields.length).toBe(1);
    expect(schema.fields[0].name).toBe("text");
    expect(schema.fields[0].type.toString()).toBe("Utf8");
  });

  it("should return full schema when no select is specified", async () => {
    const schema = await table.query().outputSchema();

    // Should return all columns
    expect(schema.fields.length).toBe(3);
  });
});
@@ -161,7 +161,6 @@ export class PermutationBuilder {
 * Create a permutation builder for the given table.
 *
 * @param table - The source table to create a permutation from
 * @param destTableName - The name for the destination permutation table
 * @returns A PermutationBuilder instance
 * @example
 * ```ts
@@ -172,17 +171,13 @@ export class PermutationBuilder {
 * const trainingTable = await builder.execute();
 * ```
 */
export function permutationBuilder(
  table: Table,
  destTableName: string,
): PermutationBuilder {
export function permutationBuilder(table: Table): PermutationBuilder {
  // Extract the inner native table from the TypeScript wrapper
  const localTable = table as LocalTable;
  // Access inner through type assertion since it's private
  const nativeBuilder = nativePermutationBuilder(
    // biome-ignore lint/suspicious/noExplicitAny: need access to private variable
    (localTable as any).inner,
    destTableName,
  );
  return new PermutationBuilder(nativeBuilder);
}
@@ -20,35 +20,25 @@ import {
} from "./native";
import { Reranker } from "./rerankers";

export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
  private promisedInner?: Promise<NativeBatchIterator>;
  private inner?: NativeBatchIterator;
export async function* RecordBatchIterator(
  promisedInner: Promise<NativeBatchIterator>,
) {
  const inner = await promisedInner;

  constructor(promise?: Promise<NativeBatchIterator>) {
    // TODO: check promise reliably so we dont need to pass two arguments.
    this.promisedInner = promise;
  if (inner === undefined) {
    throw new Error("Invalid iterator state");
  }

  // biome-ignore lint/suspicious/noExplicitAny: skip
  async next(): Promise<IteratorResult<RecordBatch<any>>> {
    if (this.inner === undefined) {
      this.inner = await this.promisedInner;
    }
    if (this.inner === undefined) {
      throw new Error("Invalid iterator state state");
    }
    const n = await this.inner.next();
    if (n == null) {
      return Promise.resolve({ done: true, value: null });
    }
    const tbl = tableFromIPC(n);
    if (tbl.batches.length != 1) {
  for (let buffer = await inner.next(); buffer; buffer = await inner.next()) {
    const { batches } = tableFromIPC(buffer);

    if (batches.length !== 1) {
      throw new Error("Expected only one batch");
    }
    return Promise.resolve({ done: false, value: tbl.batches[0] });

    yield batches[0];
  }
}
/* eslint-enable */

class RecordBatchIterable<
  NativeQueryType extends NativeQuery | NativeVectorQuery | NativeTakeQuery,
@@ -64,7 +54,7 @@ class RecordBatchIterable<

  // biome-ignore lint/suspicious/noExplicitAny: skip
  [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
    return new RecordBatchIterator(
    return RecordBatchIterator(
      this.inner.execute(this.options?.maxBatchLength, this.options?.timeoutMs),
    );
  }
@@ -231,10 +221,8 @@ export class QueryBase<
   * single query)
   *
   */
  protected execute(
    options?: Partial<QueryExecutionOptions>,
  ): RecordBatchIterator {
    return new RecordBatchIterator(this.nativeExecute(options));
  protected execute(options?: Partial<QueryExecutionOptions>) {
    return RecordBatchIterator(this.nativeExecute(options));
  }

  /**
@@ -242,8 +230,7 @@ export class QueryBase<
   */
  // biome-ignore lint/suspicious/noExplicitAny: skip
  [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
    const promise = this.nativeExecute();
    return new RecordBatchIterator(promise);
    return RecordBatchIterator(this.nativeExecute());
  }

  /** Collect the results as an Arrow @see {@link ArrowTable}. */
@@ -326,6 +313,25 @@ export class QueryBase<
    return this.inner.analyzePlan();
  }
}

  /**
   * Returns the schema of the output that will be returned by this query.
   *
   * This can be used to inspect the types and names of the columns that will be
   * returned by the query before executing it.
   *
   * @returns An Arrow Schema describing the output columns.
   */
  async outputSchema(): Promise<import("./arrow").Schema> {
    let schemaBuffer: Buffer;
    if (this.inner instanceof Promise) {
      schemaBuffer = await this.inner.then((inner) => inner.outputSchema());
    } else {
      schemaBuffer = await this.inner.outputSchema();
    }
    const schema = tableFromIPC(schemaBuffer).schema;
    return schema;
  }
}

export class StandardQueryBase<
@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-darwin-arm64",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": ["darwin"],
  "cpu": ["arm64"],
  "main": "lancedb.darwin-arm64.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-darwin-x64",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": ["darwin"],
  "cpu": ["x64"],
  "main": "lancedb.darwin-x64.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-arm64-gnu",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": ["linux"],
  "cpu": ["arm64"],
  "main": "lancedb.linux-arm64-gnu.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-arm64-musl",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": ["linux"],
  "cpu": ["arm64"],
  "main": "lancedb.linux-arm64-musl.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-x64-gnu",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": ["linux"],
  "cpu": ["x64"],
  "main": "lancedb.linux-x64-gnu.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-x64-musl",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": ["linux"],
  "cpu": ["x64"],
  "main": "lancedb.linux-x64-musl.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-win32-arm64-msvc",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": [
    "win32"
  ],

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-win32-x64-msvc",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "os": ["win32"],
  "cpu": ["x64"],
  "main": "lancedb.win32-x64-msvc.node",
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
{
  "name": "@lancedb/lancedb",
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "name": "@lancedb/lancedb",
      "version": "0.22.2",
      "version": "0.22.3-beta.3",
      "cpu": [
        "x64",
        "arm64"

@@ -11,7 +11,7 @@
    "ann"
  ],
  "private": false,
  "version": "0.22.2",
  "version": "0.22.3-beta.3",
  "main": "dist/index.js",
  "exports": {
    ".": "./dist/index.js",
@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};

use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
    permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
    split::{SplitSizes, SplitStrategy},
    permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
    permutation::split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;

@@ -40,7 +40,6 @@ pub struct ShuffleOptions {

pub struct PermutationBuilderState {
    pub builder: Option<LancePermutationBuilder>,
    pub dest_table_name: String,
}

#[napi]
@@ -49,11 +48,10 @@ pub struct PermutationBuilder {
}

impl PermutationBuilder {
    pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
    pub fn new(builder: LancePermutationBuilder) -> Self {
        Self {
            state: Arc::new(Mutex::new(PermutationBuilderState {
                builder: Some(builder),
                dest_table_name,
            })),
        }
    }
@@ -191,32 +189,26 @@ impl PermutationBuilder {
    /// Execute the permutation builder and create the table
    #[napi]
    pub async fn execute(&self) -> napi::Result<Table> {
        let (builder, dest_table_name) = {
        let builder = {
            let mut state = self.state.lock().unwrap();
            let builder = state
            state
                .builder
                .take()
                .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;

            let dest_table_name = std::mem::take(&mut state.dest_table_name);
            (builder, dest_table_name)
                .ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?
        };

        let table = builder.build(&dest_table_name).await.default_error()?;
        let table = builder.build().await.default_error()?;
        Ok(Table::new(table))
    }
}

/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(
    table: &crate::table::Table,
    dest_table_name: String,
) -> napi::Result<PermutationBuilder> {
    use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
pub fn permutation_builder(table: &crate::table::Table) -> napi::Result<PermutationBuilder> {
    use lancedb::dataloader::permutation::builder::PermutationBuilder as LancePermutationBuilder;

    let inner_table = table.inner_ref()?.clone();
    let inner_builder = LancePermutationBuilder::new(inner_table);

    Ok(PermutationBuilder::new(inner_builder, dest_table_name))
    Ok(PermutationBuilder::new(inner_builder))
}
@@ -22,7 +22,7 @@ use crate::error::NapiErrorExt;
use crate::iterator::RecordBatchIterator;
use crate::rerankers::Reranker;
use crate::rerankers::RerankerCallbacks;
use crate::util::parse_distance_type;
use crate::util::{parse_distance_type, schema_to_buffer};

#[napi]
pub struct Query {
@@ -88,6 +88,12 @@ impl Query {
        self.inner = self.inner.clone().with_row_id();
    }

    #[napi(catch_unwind)]
    pub async fn output_schema(&self) -> napi::Result<Buffer> {
        let schema = self.inner.output_schema().await.default_error()?;
        schema_to_buffer(&schema)
    }

    #[napi(catch_unwind)]
    pub async fn execute(
        &self,
@@ -273,6 +279,12 @@ impl VectorQuery {
            .rerank(Arc::new(Reranker::new(callbacks)));
    }

    #[napi(catch_unwind)]
    pub async fn output_schema(&self) -> napi::Result<Buffer> {
        let schema = self.inner.output_schema().await.default_error()?;
        schema_to_buffer(&schema)
    }

    #[napi(catch_unwind)]
    pub async fn execute(
        &self,
@@ -346,6 +358,12 @@ impl TakeQuery {
        self.inner = self.inner.clone().with_row_id();
    }

    #[napi(catch_unwind)]
    pub async fn output_schema(&self) -> napi::Result<Buffer> {
        let schema = self.inner.output_schema().await.default_error()?;
        schema_to_buffer(&schema)
    }

    #[napi(catch_unwind)]
    pub async fn execute(
        &self,
@@ -3,7 +3,6 @@

use std::collections::HashMap;

use arrow_ipc::writer::FileWriter;
use lancedb::ipc::ipc_file_to_batches;
use lancedb::table::{
    AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, NewColumnTransform,
@@ -16,6 +15,7 @@ use crate::error::NapiErrorExt;
use crate::index::Index;
use crate::merge::NativeMergeInsertBuilder;
use crate::query::{Query, TakeQuery, VectorQuery};
use crate::util::schema_to_buffer;

#[napi]
pub struct Table {
@@ -64,14 +64,7 @@ impl Table {
    #[napi(catch_unwind)]
    pub async fn schema(&self) -> napi::Result<Buffer> {
        let schema = self.inner_ref()?.schema().await.default_error()?;
        let mut writer = FileWriter::try_new(vec![], &schema)
            .map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
        writer
            .finish()
            .map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
        Ok(Buffer::from(writer.into_inner().map_err(|e| {
            napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
        })?))
        schema_to_buffer(&schema)
    }

    #[napi(catch_unwind)]
@@ -1,7 +1,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use arrow_ipc::writer::FileWriter;
use arrow_schema::Schema;
use lancedb::DistanceType;
use napi::bindgen_prelude::Buffer;

pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
    match distance_type.as_ref().to_lowercase().as_str() {
@@ -15,3 +18,15 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
        ))),
    }
}

/// Convert an Arrow Schema to an Arrow IPC file buffer
pub fn schema_to_buffer(schema: &Schema) -> napi::Result<Buffer> {
    let mut writer = FileWriter::try_new(vec![], schema)
        .map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
    writer
        .finish()
        .map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
    Ok(Buffer::from(writer.into_inner().map_err(|e| {
        napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
    })?))
}
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.25.3-beta.0"
current_version = "0.25.3-beta.4"
parse = """(?x)
    (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.
19
python/AGENTS.md
Normal file
@@ -0,0 +1,19 @@
These are the Python bindings of LanceDB.
The core Rust library is in the `../rust/lancedb` directory, the rust binding
code is in the `src/` directory and the Python bindings are in the `lancedb/` directory.

Common commands:

* Build: `make develop`
* Format: `make format`
* Lint: `make check`
* Fix lints: `make fix`
* Test: `make test`
* Doc test: `make doctest`

Before committing changes, run lints and then formatting.

When you change the Rust code, you will need to recompile the Python bindings: `make develop`.

When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.
@@ -1,19 +0,0 @@
These are the Python bindings of LanceDB.
The core Rust library is in the `../rust/lancedb` directory, the rust binding
code is in the `src/` directory and the Python bindings are in the `lancedb/` directory.

Common commands:

* Build: `make develop`
* Format: `make format`
* Lint: `make check`
* Fix lints: `make fix`
* Test: `make test`
* Doc test: `make doctest`

Before committing changes, run lints and then formatting.

When you change the Rust code, you will need to recompile the Python bindings: `make develop`.

When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.
1
python/CLAUDE.md
Symbolic link
@@ -0,0 +1 @@
AGENTS.md
@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.25.3-beta.0"
version = "0.25.3-beta.4"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
@@ -123,6 +123,8 @@ class Table:
    @property
    def tags(self) -> Tags: ...
    def query(self) -> Query: ...
    def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
    def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
    def vector_search(self) -> VectorQuery: ...

class Tags:
@@ -165,6 +167,7 @@ class Query:
    def postfilter(self): ...
    def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
    def nearest_to_text(self, query: dict) -> FTSQuery: ...
    async def output_schema(self) -> pa.Schema: ...
    async def execute(
        self, max_batch_length: Optional[int], timeout: Optional[timedelta]
    ) -> RecordBatchStream: ...
@@ -172,6 +175,13 @@ class Query:
    async def analyze_plan(self) -> str: ...
    def to_query_request(self) -> PyQueryRequest: ...

class TakeQuery:
    def select(self, columns: List[str]): ...
    def with_row_id(self): ...
    async def output_schema(self) -> pa.Schema: ...
    async def execute(self) -> RecordBatchStream: ...
    def to_query_request(self) -> PyQueryRequest: ...

class FTSQuery:
    def where(self, filter: str): ...
    def select(self, columns: List[str]): ...
@@ -183,12 +193,14 @@ class FTSQuery:
    def get_query(self) -> str: ...
    def add_query_vector(self, query_vec: pa.Array) -> None: ...
    def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
    async def output_schema(self) -> pa.Schema: ...
    async def execute(
        self, max_batch_length: Optional[int], timeout: Optional[timedelta]
    ) -> RecordBatchStream: ...
    def to_query_request(self) -> PyQueryRequest: ...

class VectorQuery:
    async def output_schema(self) -> pa.Schema: ...
    async def execute(self) -> RecordBatchStream: ...
    def where(self, filter: str): ...
    def select(self, columns: List[str]): ...
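The `output_schema` stubs above surface in the async Python API. A minimal sketch, not part of the diff (the database path and table contents are invented for illustration):

```python
# Sketch: resolve a query's projected Arrow schema without reading any rows,
# using the async output_schema() added in this changeset.
import asyncio

import lancedb


async def main() -> None:
    db = await lancedb.connect_async("/tmp/lancedb-demo")  # hypothetical path
    table = await db.create_table(
        "items", [{"id": 1, "text": "hello"}], mode="overwrite"
    )
    schema = await table.query().select(["id"]).output_schema()
    print(schema)  # contains only the "id" field; no data is scanned


asyncio.run(main())
```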
@@ -3,9 +3,11 @@

from functools import lru_cache
from typing import List, Union, Optional, Any
from logging import warning
from typing import List, Union, Optional, Any, Callable
import numpy as np
import io
import warnings

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
@@ -19,35 +21,52 @@ class ColPaliEmbeddings(EmbeddingFunction):
    An embedding function that uses the ColPali engine for
    multimodal multi-vector embeddings.

    This embedding function supports ColQwen2.5 models, producing multivector outputs
    for both text and image inputs. The output embeddings are lists of vectors, each
    vector being 128-dimensional by default, represented as List[List[float]].
    This embedding function supports ColPali models, producing multivector outputs
    for both text and image inputs.

    Parameters
    ----------
    model_name : str
        The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
        Supports models based on these engines:
        - ColPali: "vidore/colpali-v1.3" and others
        - ColQwen2.5: "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" and others
        - ColQwen2: "vidore/colqwen2-v1.0" and others
        - ColSmol: "vidore/colSmol-256M" and others

    device : str
        The device for inference (default "cuda:0").
        The device for inference (default "auto").
    dtype : str
        Data type for model weights (default "bfloat16").
    use_token_pooling : bool
        Whether to use token pooling to reduce embedding size (default True).
        DEPRECATED. Whether to use token pooling. Use `pooling_strategy` instead.
    pooling_strategy : str, optional
        The token pooling strategy to use, by default "hierarchical".
        - "hierarchical": Progressively pools tokens to reduce sequence length.
        - "lambda": A simpler pooling that uses a custom `pooling_func`.
    pooling_func: typing.Callable, optional
        A function to use for pooling when `pooling_strategy` is "lambda".
    pool_factor : int
        Factor to reduce sequence length if token pooling is enabled (default 2).
    quantization_config : Optional[BitsAndBytesConfig]
        Quantization configuration for the model. (default None, bitsandbytes needed)
    batch_size : int
        Batch size for processing inputs (default 2).
    offload_folder: str, optional
        Folder to offload model weights if using CPU offloading (default None). This is
        useful for large models that do not fit in memory.
    """

    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
    device: str = "auto"
    dtype: str = "bfloat16"
    use_token_pooling: bool = True
    pooling_strategy: Optional[str] = "hierarchical"
    pooling_func: Optional[Any] = None
    pool_factor: int = 2
    quantization_config: Optional[Any] = None
    batch_size: int = 2
    offload_folder: Optional[str] = None

    _model = None
    _processor = None
@@ -56,15 +75,43 @@ class ColPaliEmbeddings(EmbeddingFunction):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        torch = attempt_import_or_raise("torch", "torch")

        if not self.use_token_pooling:
            warnings.warn(
                "use_token_pooling is deprecated, use pooling_strategy=None instead",
                DeprecationWarning,
            )
            self.pooling_strategy = None

        if self.pooling_strategy == "lambda" and self.pooling_func is None:
            raise ValueError(
                "pooling_func must be provided when pooling_strategy is 'lambda'"
            )

        device = self.device
        if device == "auto":
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        dtype = self.dtype
        if device == "mps" and dtype == "bfloat16":
            dtype = "float32"  # Avoid NaNs on MPS

        (
            self._model,
            self._processor,
            self._token_pooler,
        ) = self._load_model(
            self.model_name,
            self.dtype,
            self.device,
            self.use_token_pooling,
            dtype,
            device,
            self.pooling_strategy,
            self.pooling_func,
            self.quantization_config,
        )

@@ -74,16 +121,26 @@ class ColPaliEmbeddings(EmbeddingFunction):
        model_name: str,
        dtype: str,
        device: str,
        use_token_pooling: bool,
        pooling_strategy: Optional[str],
        pooling_func: Optional[Callable],
        quantization_config: Optional[Any],
    ):
        """
        Initialize and cache the ColPali model, processor, and token pooler.
        """
        if device.startswith("mps"):
            # warn some torch ops in late interaction architecture result in nans on mps
            warning(
                "MPS device detected. Some operations may result in NaNs. "
                "If you encounter issues, consider using 'cpu' or 'cuda' devices."
            )
        torch = attempt_import_or_raise("torch", "torch")
        transformers = attempt_import_or_raise("transformers", "transformers")
        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
        from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
        from colpali_engine.compression.token_pooling import (
            HierarchicalTokenPooler,
            LambdaTokenPooler,
        )

        if quantization_config is not None:
            if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
@@ -98,21 +155,45 @@ class ColPaliEmbeddings(EmbeddingFunction):
        else:
            torch_dtype = torch.float32

        model = colpali_engine.models.ColQwen2_5.from_pretrained(
        model_class, processor_class = None, None
        model_name_lower = model_name.lower()
        if "colqwen2.5" in model_name_lower:
            model_class = colpali_engine.models.ColQwen2_5
            processor_class = colpali_engine.models.ColQwen2_5_Processor
        elif "colsmol" in model_name_lower or "colidefics3" in model_name_lower:
            model_class = colpali_engine.models.ColIdefics3
            processor_class = colpali_engine.models.ColIdefics3Processor
        elif "colqwen" in model_name_lower:
            model_class = colpali_engine.models.ColQwen2
            processor_class = colpali_engine.models.ColQwen2Processor
        elif "colpali" in model_name_lower:
            model_class = colpali_engine.models.ColPali
            processor_class = colpali_engine.models.ColPaliProcessor

        if model_class is None:
            raise ValueError(f"Unsupported model: {model_name}")

        model = model_class.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device,
            quantization_config=quantization_config
            if quantization_config is not None
            else None,
            attn_implementation="flash_attention_2"
            if is_flash_attn_2_available()
            else None,
            low_cpu_mem_usage=True,
        ).eval()
        processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
            model_name
        )
        token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
        model = model.to(device)
        model = model.to(torch_dtype)  # Force cast after moving to device
        processor = processor_class.from_pretrained(model_name)

        token_pooler = None
        if pooling_strategy == "hierarchical":
            token_pooler = HierarchicalTokenPooler()
        elif pooling_strategy == "lambda":
            token_pooler = LambdaTokenPooler(pool_func=pooling_func)

        return model, processor, token_pooler

    def ndims(self):
@@ -128,7 +209,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
        with torch.no_grad():
            query_embeddings = self._model(**batch_queries)

        if self.use_token_pooling and self._token_pooler is not None:
        if self.pooling_strategy and self._token_pooler is not None:
            query_embeddings = self._token_pooler.pool_embeddings(
                query_embeddings,
                pool_factor=self.pool_factor,
@@ -145,13 +226,20 @@ class ColPaliEmbeddings(EmbeddingFunction):
        Use token pooling if enabled.
        """
        torch = attempt_import_or_raise("torch", "torch")
        if self.use_token_pooling and self._token_pooler is not None:
            embeddings = self._token_pooler.pool_embeddings(
                embeddings,
                pool_factor=self.pool_factor,
                padding=True,
                padding_side=self._processor.tokenizer.padding_side,
            )
        if self.pooling_strategy and self._token_pooler is not None:
            if self.pooling_strategy == "hierarchical":
                embeddings = self._token_pooler.pool_embeddings(
                    embeddings,
                    pool_factor=self.pool_factor,
                    padding=True,
                    padding_side=self._processor.tokenizer.padding_side,
                )
            elif self.pooling_strategy == "lambda":
                embeddings = self._token_pooler.pool_embeddings(
                    embeddings,
                    padding=True,
                    padding_side=self._processor.tokenizer.padding_side,
                )

        if isinstance(embeddings, torch.Tensor):
            tensors = embeddings.detach().cpu()
@@ -179,6 +267,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
            )
            with torch.no_grad():
                query_embeddings = self._model(**batch_queries)
                query_embeddings = torch.nan_to_num(query_embeddings)
            all_embeddings.extend(self._process_embeddings(query_embeddings))
        return all_embeddings

@@ -225,6 +314,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
            )
            with torch.no_grad():
                image_embeddings = self._model(**batch_images)
                image_embeddings = torch.nan_to_num(image_embeddings)
            all_embeddings.extend(self._process_embeddings(image_embeddings))
        return all_embeddings
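The pooling options documented above compose at creation time. A hedged sketch (downloads model weights on first run and needs the optional `torch`/`colpali_engine` extras; the model name is one of the families listed in the docstring):

```python
# Sketch: hierarchical token pooling roughly halves the number of output
# vectors at pool_factor=2; pooling_strategy=None keeps the full multivector.
from lancedb.embeddings import get_registry

func = get_registry().get("colpali").create(
    model_name="vidore/colSmol-256M",  # any supported ColPali-family model
    pooling_strategy="hierarchical",   # or "lambda" with pooling_func=..., or None
    pool_factor=2,
)
vectors = func.generate_text_embeddings(["a query about cats"])[0]
print(len(vectors), len(vectors[0]))  # (pooled token count, embedding dim)
```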
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import base64
import os
from typing import ClassVar, TYPE_CHECKING, List, Union, Any
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator

from pathlib import Path
from urllib.parse import urlparse
@@ -19,6 +19,23 @@ from .utils import api_key_not_found_help, IMAGES, TEXT
if TYPE_CHECKING:
    import PIL

# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
    "voyage-context-3": 32_000,
    "voyage-3.5-lite": 1_000_000,
    "voyage-3.5": 320_000,
    "voyage-3-lite": 120_000,
    "voyage-3": 120_000,
    "voyage-multimodal-3": 120_000,
    "voyage-finance-2": 120_000,
    "voyage-multilingual-2": 120_000,
    "voyage-law-2": 120_000,
    "voyage-code-2": 120_000,
}

# Batch size for embedding requests (max number of items per batch)
BATCH_SIZE = 1000


def is_valid_url(text):
    try:
@@ -120,6 +137,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
    name: str
        The name of the model to use. List of acceptable models:

        * voyage-context-3
        * voyage-3.5
        * voyage-3.5-lite
        * voyage-3
        * voyage-3-lite
        * voyage-multimodal-3
@@ -157,25 +177,35 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
    name: str
    client: ClassVar = None
    text_embedding_models: list = [
        "voyage-3.5",
        "voyage-3.5-lite",
        "voyage-3",
        "voyage-3-lite",
        "voyage-finance-2",
        "voyage-multilingual-2",
        "voyage-law-2",
        "voyage-code-2",
    ]
    multimodal_embedding_models: list = ["voyage-multimodal-3"]
    contextual_embedding_models: list = ["voyage-context-3"]

    def _is_multimodal_model(self, model_name: str):
        return (
            model_name in self.multimodal_embedding_models or "multimodal" in model_name
        )

    def _is_contextual_model(self, model_name: str):
        return model_name in self.contextual_embedding_models or "context" in model_name

    def ndims(self):
        if self.name == "voyage-3-lite":
            return 512
        elif self.name == "voyage-code-2":
            return 1536
        elif self.name in [
            "voyage-context-3",
            "voyage-3.5",
            "voyage-3.5-lite",
            "voyage-3",
            "voyage-multimodal-3",
            "voyage-finance-2",
@@ -207,6 +237,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
            result = client.multimodal_embed(
                inputs=[[query]], model=self.name, input_type="query", **kwargs
            )
        elif self._is_contextual_model(self.name):
            result = client.contextualized_embed(
                inputs=[[query]], model=self.name, input_type="query", **kwargs
            )
            result = result.results[0]
        else:
            result = client.embed(
                texts=[query], model=self.name, input_type="query", **kwargs
@@ -231,18 +266,164 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
        List[np.array]: the list of embeddings
        """
        client = VoyageAIEmbeddingFunction._get_client()

        # For multimodal models, check if inputs contain images
        if self._is_multimodal_model(self.name):
            inputs = sanitize_multimodal_input(inputs)
            result = client.multimodal_embed(
                inputs=inputs, model=self.name, input_type="document", **kwargs
            sanitized = sanitize_multimodal_input(inputs)
            has_images = any(
                inp["content"][0].get("type") != "text" for inp in sanitized
            )
            if has_images:
                # Use non-batched API for images
                result = client.multimodal_embed(
                    inputs=sanitized, model=self.name, input_type="document", **kwargs
                )
                return result.embeddings
            # Extract texts for batching
            inputs = [inp["content"][0]["text"] for inp in sanitized]
        else:
            inputs = sanitize_text_input(inputs)
            result = client.embed(
                texts=inputs, model=self.name, input_type="document", **kwargs
            )

        return result.embeddings
        # Use batching for all text inputs
        return self._embed_with_batching(
            client, inputs, input_type="document", **kwargs
        )

    def _build_batches(
        self, client, texts: List[str]
    ) -> Generator[List[str], None, None]:
        """
        Generate batches of texts based on token limits using a generator.

        Parameters
        ----------
        client : voyageai.Client
            The VoyageAI client instance.
        texts : List[str]
            List of texts to batch.

        Yields
        ------
        List[str]: Batches of texts.
        """
        if not texts:
            return

        max_tokens_per_batch = VOYAGE_TOTAL_TOKEN_LIMITS.get(self.name, 120_000)
        current_batch: List[str] = []
        current_batch_tokens = 0

        # Tokenize all texts in one API call
        token_lists = client.tokenize(texts, model=self.name)
        token_counts = [len(token_list) for token_list in token_lists]

        for i, text in enumerate(texts):
            n_tokens = token_counts[i]

            # Check if adding this text would exceed limits
            if current_batch and (
                len(current_batch) >= BATCH_SIZE
                or (current_batch_tokens + n_tokens > max_tokens_per_batch)
            ):
                # Yield the current batch and start a new one
                yield current_batch
                current_batch = []
                current_batch_tokens = 0

            current_batch.append(text)
            current_batch_tokens += n_tokens

        # Yield the last batch (always has at least one text)
        if current_batch:
            yield current_batch

    def _get_embed_function(
        self, client, input_type: str = "document", **kwargs
    ) -> callable:
        """
        Get the appropriate embedding function based on model type.

        Parameters
        ----------
        client : voyageai.Client
            The VoyageAI client instance.
        input_type : str
            Either "query" or "document"
        **kwargs
            Additional arguments to pass to the embedding API

        Returns
        -------
        callable: A function that takes a batch of texts and returns embeddings.
        """
        if self._is_multimodal_model(self.name):

            def embed_batch(batch: List[str]) -> List[np.array]:
                batch_inputs = sanitize_multimodal_input(batch)
                result = client.multimodal_embed(
                    inputs=batch_inputs,
                    model=self.name,
                    input_type=input_type,
                    **kwargs,
                )
                return result.embeddings

            return embed_batch

        elif self._is_contextual_model(self.name):

            def embed_batch(batch: List[str]) -> List[np.array]:
                result = client.contextualized_embed(
                    inputs=[batch], model=self.name, input_type=input_type, **kwargs
                )
                return result.results[0].embeddings

            return embed_batch

        else:

            def embed_batch(batch: List[str]) -> List[np.array]:
                result = client.embed(
                    texts=batch, model=self.name, input_type=input_type, **kwargs
                )
                return result.embeddings

            return embed_batch

    def _embed_with_batching(
        self, client, texts: List[str], input_type: str = "document", **kwargs
    ) -> List[np.array]:
        """
        Embed texts with automatic batching based on token limits.

        Parameters
        ----------
        client : voyageai.Client
            The VoyageAI client instance.
        texts : List[str]
            List of texts to embed.
        input_type : str
            Either "query" or "document"
        **kwargs
            Additional arguments to pass to the embedding API

        Returns
        -------
        List[np.array]: List of embeddings.
        """
        if not texts:
            return []

        # Get the appropriate embedding function for this model type
        embed_fn = self._get_embed_function(client, input_type=input_type, **kwargs)

        # Process each batch
        all_embeddings = []
        for batch in self._build_batches(client, texts):
            batch_embeddings = embed_fn(batch)
            all_embeddings.extend(batch_embeddings)

        return all_embeddings

    @staticmethod
    def _get_client():
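The token-limit batching above is easiest to see in isolation. A self-contained sketch of the same greedy rule, with stub token counts standing in for `client.tokenize`; the constants mirror `BATCH_SIZE` and the `VOYAGE_TOTAL_TOKEN_LIMITS` default:

```python
# Sketch of the greedy rule from _build_batches: a batch is flushed once
# adding the next text would exceed the item cap or the token budget.
from typing import Iterator, List

BATCH_SIZE = 1000      # max items per request
MAX_TOKENS = 120_000   # default per-model token budget


def build_batches(texts: List[str], token_counts: List[int]) -> Iterator[List[str]]:
    batch: List[str] = []
    batch_tokens = 0
    for text, n_tokens in zip(texts, token_counts):
        if batch and (len(batch) >= BATCH_SIZE or batch_tokens + n_tokens > MAX_TOKENS):
            yield batch
            batch, batch_tokens = [], 0
        batch.append(text)
        batch_tokens += n_tokens
    if batch:  # the final batch always holds at least one text
        yield batch


texts = ["short", "a long document ...", "another long document ..."]
counts = [1, 80_000, 60_000]  # pretend token counts from client.tokenize()
print([len(b) for b in build_batches(texts, counts)])  # -> [2, 1]
```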
@@ -8,8 +8,8 @@ from typing import Optional


class PermutationBuilder:
    def __init__(self, table: LanceTable, dest_table_name: str):
        self._async = async_permutation_builder(table, dest_table_name)
    def __init__(self, table: LanceTable):
        self._async = async_permutation_builder(table)

    def select(self, projections: dict[str, str]) -> "PermutationBuilder":
        self._async.select(projections)
@@ -68,5 +68,5 @@ class PermutationBuilder:
        return LOOP.run(do_execute())


def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
    return PermutationBuilder(table, dest_table_name)
def permutation_builder(table: LanceTable) -> PermutationBuilder:
    return PermutationBuilder(table)
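The same signature change applies on the sync Python side. A brief sketch (the import path for `permutation_builder` is an assumption based on this diff's file layout, not confirmed by it):

```python
# Sketch: the destination table name argument is gone, so a permutation is
# derived directly from the source table.
import lancedb
import pyarrow as pa
from lancedb.dataloader import permutation_builder  # import path assumed

db = lancedb.connect("/tmp/lancedb-perm-demo")
tbl = db.create_table("demo", pa.table({"x": list(range(100))}), mode="overwrite")

perm = (
    permutation_builder(tbl)
    .split_random(ratios=[0.8, 0.2], seed=42)
    .shuffle(seed=7)
    .execute()
)
print(perm.count_rows())  # 100
```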
@@ -1237,6 +1237,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        self._refine_factor = refine_factor
        return self

    def output_schema(self) -> pa.Schema:
        """
        Return the output schema for the query

        This does not execute the query.
        """
        return self._table._output_schema(self.to_query_object())

    def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
        """
        Execute the query and return the results as an
@@ -1452,6 +1460,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
            offset=self._offset,
        )

    def output_schema(self) -> pa.Schema:
        """
        Return the output schema for the query

        This does not execute the query.
        """
        return self._table._output_schema(self.to_query_object())

    def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
        path, fs, exist = self._table._get_fts_index_path()
        if exist:
@@ -1595,6 +1611,10 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
            offset=self._offset,
        )

    def output_schema(self) -> pa.Schema:
        query = self.to_query_object()
        return self._table._output_schema(query)

    def to_batches(
        self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
    ) -> pa.RecordBatchReader:
@@ -2238,6 +2258,14 @@ class AsyncQueryBase(object):
            )
        )

    async def output_schema(self) -> pa.Schema:
        """
        Return the output schema for the query

        This does not execute the query.
        """
        return await self._inner.output_schema()

    async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
        """
        Execute the query and collect the results into an Apache Arrow Table.
@@ -3193,6 +3221,14 @@ class BaseQueryBuilder(object):
        self._inner.with_row_id()
        return self

    def output_schema(self) -> pa.Schema:
        """
        Return the output schema for the query

        This does not execute the query.
        """
        return LOOP.run(self._inner.output_schema())

    def to_batches(
        self,
        *,
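For the synchronous builders, `output_schema()` offers the same pre-flight check. A minimal sketch (table contents invented for illustration):

```python
# Sketch: inspect the projected schema of a sync query before executing it.
import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb-schema-demo")
tbl = db.create_table("docs", pa.table({"a": [1, 2, 3]}), mode="overwrite")

schema = tbl.search(None).select(["a"]).output_schema()
print(schema.field("a").type)  # int64; no rows are scanned
```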
@@ -436,6 +436,9 @@ class RemoteTable(Table):
    def _analyze_plan(self, query: Query) -> str:
        return LOOP.run(self._table._analyze_plan(query))

    def _output_schema(self, query: Query) -> pa.Schema:
        return LOOP.run(self._table._output_schema(query))

    def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
        """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
        that can be used to create a "merge insert" operation.

@@ -21,6 +21,8 @@ class VoyageAIReranker(Reranker):
    ----------
    model_name : str, default "rerank-english-v2.0"
        The name of the cross encoder model to use. Available voyageai models are:
        - rerank-2.5
        - rerank-2.5-lite
        - rerank-2
        - rerank-2-lite
    column : str, default "text"
@@ -1248,6 +1248,9 @@ class Table(ABC):
    @abstractmethod
    def _analyze_plan(self, query: Query) -> str: ...

    @abstractmethod
    def _output_schema(self, query: Query) -> pa.Schema: ...

    @abstractmethod
    def _do_merge(
        self,
@@ -2761,6 +2764,9 @@ class LanceTable(Table):
    def _analyze_plan(self, query: Query) -> str:
        return LOOP.run(self._table._analyze_plan(query))

    def _output_schema(self, query: Query) -> pa.Schema:
        return LOOP.run(self._table._output_schema(query))

    def _do_merge(
        self,
        merge: LanceMergeInsertBuilder,
@@ -3918,6 +3924,10 @@ class AsyncTable:
        async_query = self._sync_query_to_async(query)
        return await async_query.analyze_plan()

    async def _output_schema(self, query: Query) -> pa.Schema:
        async_query = self._sync_query_to_async(query)
        return await async_query.output_schema()

    async def _do_merge(
        self,
        merge: LanceMergeInsertBuilder,
@@ -532,6 +532,27 @@ def test_voyageai_embedding_function():
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function_contextual_model():
    voyageai = (
        get_registry().get("voyageai").create(name="voyage-context-3", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
@@ -656,6 +677,106 @@ def test_colpali(tmp_path):
    )


@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("colpali_engine") is None,
    reason="colpali_engine not installed",
)
@pytest.mark.parametrize(
    "model_name",
    [
        "vidore/colSmol-256M",
        "vidore/colqwen2.5-v0.2",
        "vidore/colpali-v1.3",
        "vidore/colqwen2-v1.0",
    ],
)
def test_colpali_models(tmp_path, model_name):
    import requests
    from lancedb.pydantic import LanceModel

    db = lancedb.connect(tmp_path)
    registry = get_registry()
    func = registry.get("colpali").create(model_name=model_name)

    class MediaItems(LanceModel):
        text: str
        image_uri: str = func.SourceField()
        image_bytes: bytes = func.SourceField()
        image_vectors: MultiVector(func.ndims()) = func.VectorField()

    table = db.create_table(f"media_{model_name.replace('/', '_')}", schema=MediaItems)

    texts = [
        "a cute cat playing with yarn",
    ]

    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    ]

    image_bytes = [requests.get(uri).content for uri in uris]

    table.add(
        pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
    )

    image_results = (
        table.search("fluffy companion", vector_column_name="image_vectors")
        .limit(1)
        .to_pydantic(MediaItems)[0]
    )
    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()

    first_row = table.to_arrow().to_pylist()[0]
    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
    assert len(first_row["image_vectors"][0]) == func.ndims(), (
        "Vector dimension mismatch"
    )


@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("colpali_engine") is None,
    reason="colpali_engine not installed",
)
def test_colpali_pooling(tmp_path):
    registry = get_registry()
    model_name = "vidore/colSmol-256M"
    test_sentence = "a test sentence for pooling"

    # 1. Get embeddings with no pooling
    func_no_pool = registry.get("colpali").create(
        model_name=model_name, pooling_strategy=None
    )
    unpooled_embeddings = func_no_pool.generate_text_embeddings([test_sentence])[0]
    original_length = len(unpooled_embeddings)
    assert original_length > 1

    # 2. Test hierarchical pooling
    func_hierarchical = registry.get("colpali").create(
        model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
    )
    hierarchical_embeddings = func_hierarchical.generate_text_embeddings(
        [test_sentence]
    )[0]
    expected_hierarchical_length = (original_length + 1) // 2
    assert len(hierarchical_embeddings) == expected_hierarchical_length

    # 3. Test lambda pooling
    def simple_pool_func(tensor):
        return tensor[::2]

    func_lambda = registry.get("colpali").create(
        model_name=model_name,
        pooling_strategy="lambda",
        pooling_func=simple_pool_func,
    )
    lambda_embeddings = func_lambda.generate_text_embeddings([test_sentence])[0]
    expected_lambda_length = (original_length + 1) // 2
    assert len(lambda_embeddings) == expected_lambda_length


@pytest.mark.slow
def test_siglip(tmp_path, test_images, query_image_bytes):
    from PIL import Image
@@ -59,6 +59,14 @@ class TempNamespace(LanceNamespace):
            root
        ]  # Reference to shared namespaces

    def namespace_id(self) -> str:
        """Return a human-readable unique identifier for this namespace instance.

        Returns:
            A unique identifier string based on the root directory
        """
        return f"TempNamespace {{ root: '{self.config.root}' }}"

    def list_tables(self, request: ListTablesRequest) -> ListTablesResponse:
        """List all tables in the namespace."""
        if not request.id:
@@ -12,11 +12,7 @@ def test_split_random_ratios(mem_db):
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_random(ratios=[0.3, 0.7])
        .execute()
    )
    permutation_tbl = permutation_builder(tbl).split_random(ratios=[0.3, 0.7]).execute()

    # Check that the table was created and has data
    assert permutation_tbl.count_rows() == 100
@@ -38,11 +34,7 @@ def test_split_random_counts(mem_db):
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_random(counts=[20, 30])
        .execute()
    )
    permutation_tbl = permutation_builder(tbl).split_random(counts=[20, 30]).execute()

    # Check that we have exactly the requested counts
    assert permutation_tbl.count_rows() == 50
@@ -58,9 +50,7 @@ def test_split_random_fixed(mem_db):
    tbl = mem_db.create_table(
        "test_table", pa.table({"x": range(100), "y": range(100)})
    )
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
    )
    permutation_tbl = permutation_builder(tbl).split_random(fixed=4).execute()

    # Check that we have 4 splits with 25 rows each
    assert permutation_tbl.count_rows() == 100
@@ -78,17 +68,9 @@ def test_split_random_with_seed(mem_db):
    tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))

    # Create two identical permutations with same seed
    perm1 = (
        permutation_builder(tbl, "perm1")
        .split_random(ratios=[0.6, 0.4], seed=42)
        .execute()
    )
    perm1 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()

    perm2 = (
        permutation_builder(tbl, "perm2")
        .split_random(ratios=[0.6, 0.4], seed=42)
        .execute()
    )
    perm2 = permutation_builder(tbl).split_random(ratios=[0.6, 0.4], seed=42).execute()

    # Results should be identical
    data1 = perm1.search(None).to_arrow().to_pydict()
@@ -112,7 +94,7 @@ def test_split_hash(mem_db):
    )

    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=0)
        .execute()
    )
@@ -133,7 +115,7 @@ def test_split_hash(mem_db):
    # Hash splits should be deterministic - same category should go to same split
    # Let's verify by creating another permutation and checking consistency
    perm2 = (
        permutation_builder(tbl, "test_permutation2")
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=0)
        .execute()
    )
@@ -150,7 +132,7 @@ def test_split_hash_with_discard(mem_db):
    )

    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=2)  # Should discard ~50%
        .execute()
    )
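One property of the hash splits exercised above is worth spelling out: a row's split id is a pure function of the hashed column values, so rebuilding the permutation reproduces the same assignment, and the weights set each split's relative share. A minimal sketch of a stable split keyed on one column, under the assumption (consistent with the [1, 1] usage above) that weights are proportional shares; fixture and import conventions mirror the tests:

def test_stable_hash_split(mem_db):
    tbl = mem_db.create_table(
        "events", pa.table({"user_id": range(1000), "value": range(1000)})
    )
    # Weights [4, 1] aim for roughly 80/20; discard_weight=0 keeps every row.
    perm = (
        permutation_builder(tbl)
        .split_hash(["user_id"], [4, 1], discard_weight=0)
        .execute()
    )
    assert perm.count_rows() == 1000  # nothing discarded

    # Rebuilding yields the identical assignment: same rows, same split ids.
    perm2 = (
        permutation_builder(tbl)
        .split_hash(["user_id"], [4, 1], discard_weight=0)
        .execute()
    )
    assert perm2.count_rows() == perm.count_rows()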
@@ -168,9 +150,7 @@ def test_split_sequential(mem_db):
    )

    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        .split_sequential(counts=[30, 40])
        .execute()
        permutation_builder(tbl).split_sequential(counts=[30, 40]).execute()
    )

    assert permutation_tbl.count_rows() == 70
@@ -194,7 +174,7 @@ def test_split_calculated(mem_db):
    )

    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .split_calculated("id % 3")  # Split based on id modulo 3
        .execute()
    )
@@ -216,23 +196,21 @@ def test_split_error_cases(mem_db):

    # Test split_random with no parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error1").split_random().execute()
        permutation_builder(tbl).split_random().execute()

    # Test split_random with multiple parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error2").split_random(
        permutation_builder(tbl).split_random(
            ratios=[0.5, 0.5], counts=[5, 5]
        ).execute()

    # Test split_sequential with no parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error3").split_sequential().execute()
        permutation_builder(tbl).split_sequential().execute()

    # Test split_sequential with multiple parameters
    with pytest.raises(Exception):
        permutation_builder(tbl, "error4").split_sequential(
            ratios=[0.5, 0.5], fixed=2
        ).execute()
        permutation_builder(tbl).split_sequential(ratios=[0.5, 0.5], fixed=2).execute()


def test_shuffle_no_seed(mem_db):
@@ -242,7 +220,7 @@ def test_shuffle_no_seed(mem_db):
    )

    # Create a permutation with shuffling (no seed)
    permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
    permutation_tbl = permutation_builder(tbl).shuffle().execute()

    assert permutation_tbl.count_rows() == 100

@@ -262,9 +240,9 @@ def test_shuffle_with_seed(mem_db):
    )

    # Create two identical permutations with same shuffle seed
    perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
    perm1 = permutation_builder(tbl).shuffle(seed=42).execute()

    perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
    perm2 = permutation_builder(tbl).shuffle(seed=42).execute()

    # Results should be identical due to same seed
    data1 = perm1.search(None).to_arrow().to_pydict()
@@ -282,7 +260,7 @@ def test_shuffle_with_clump_size(mem_db):

    # Create a permutation with shuffling using clumps
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .shuffle(clump_size=10)  # 10-row clumps
        .execute()
    )
@@ -304,19 +282,9 @@ def test_shuffle_different_seeds(mem_db):
    )

    # Create two permutations with different shuffle seeds
    perm1 = (
        permutation_builder(tbl, "perm1")
        .split_random(fixed=2)
        .shuffle(seed=42)
        .execute()
    )
    perm1 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=42).execute()

    perm2 = (
        permutation_builder(tbl, "perm2")
        .split_random(fixed=2)
        .shuffle(seed=123)
        .execute()
    )
    perm2 = permutation_builder(tbl).split_random(fixed=2).shuffle(seed=123).execute()

    # Results should be different due to different seeds
    data1 = perm1.search(None).to_arrow().to_pydict()
@@ -341,7 +309,7 @@ def test_shuffle_combined_with_splits(mem_db):

    # Test shuffle with random splits
    perm_random = (
        permutation_builder(tbl, "perm_random")
        permutation_builder(tbl)
        .split_random(ratios=[0.6, 0.4], seed=42)
        .shuffle(seed=123, clump_size=None)
        .execute()
@@ -349,7 +317,7 @@ def test_shuffle_combined_with_splits(mem_db):

    # Test shuffle with hash splits
    perm_hash = (
        permutation_builder(tbl, "perm_hash")
        permutation_builder(tbl)
        .split_hash(["category"], [1, 1], discard_weight=0)
        .shuffle(seed=456, clump_size=5)
        .execute()
@@ -357,7 +325,7 @@ def test_shuffle_combined_with_splits(mem_db):

    # Test shuffle with sequential splits
    perm_sequential = (
        permutation_builder(tbl, "perm_sequential")
        permutation_builder(tbl)
        .split_sequential(counts=[40, 35])
        .shuffle(seed=789, clump_size=None)
        .execute()
@@ -384,7 +352,7 @@ def test_no_shuffle_maintains_order(mem_db):

    # Create permutation without shuffle (should maintain some order)
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .split_sequential(counts=[25, 25])  # Sequential maintains order
        .execute()
    )
@@ -405,9 +373,7 @@ def test_filter_basic(mem_db):
    )

    # Filter to only include rows where id < 50
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
    )
    permutation_tbl = permutation_builder(tbl).filter("id < 50").execute()

    assert permutation_tbl.count_rows() == 50

@@ -433,7 +399,7 @@ def test_filter_with_splits(mem_db):

    # Filter to only category A and B, then split
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .filter("category IN ('A', 'B')")
        .split_random(ratios=[0.5, 0.5])
        .execute()
@@ -465,7 +431,7 @@ def test_filter_with_shuffle(mem_db):

    # Filter and shuffle
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .filter("category IN ('A', 'C')")
        .shuffle(seed=42)
        .execute()
@@ -488,7 +454,7 @@ def test_filter_empty_result(mem_db):

    # Filter that matches nothing
    permutation_tbl = (
        permutation_builder(tbl, "test_permutation")
        permutation_builder(tbl)
        .filter("value > 100")  # No values > 100 in our data
        .execute()
    )
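Read together, the rewritten tests show the shape of the new API: permutation_builder now takes only the source table (the destination-name argument is gone), and filter, split, and shuffle steps chain in any combination before execute(). A compact sketch combining the pieces exercised above (fixture names mirror the tests):

def test_filter_split_shuffle_combined(mem_db):
    tbl = mem_db.create_table(
        "test_table", pa.table({"id": range(100), "category": ["A", "B"] * 50})
    )
    perm = (
        permutation_builder(tbl)
        .filter("category = 'A'")                 # keep 50 of 100 rows
        .split_random(ratios=[0.8, 0.2], seed=7)  # deterministic 80/20 split
        .shuffle(seed=7, clump_size=None)         # row-level shuffle
        .execute()
    )
    assert perm.count_rows() == 50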
@@ -1298,6 +1298,79 @@ async def test_query_serialization_async(table_async: AsyncTable):
    )


def test_query_schema(tmp_path):
    db = lancedb.connect(tmp_path)
    tbl = db.create_table(
        "test",
        pa.table(
            {
                "a": [1, 2, 3],
                "text": ["a", "b", "c"],
                "vec": pa.array(
                    [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
                ),
            }
        ),
    )

    assert tbl.search(None).output_schema() == pa.schema(
        {
            "a": pa.int64(),
            "text": pa.string(),
            "vec": pa.list_(pa.float32(), list_size=2),
        }
    )
    assert tbl.search(None).select({"bl": "a * 2"}).output_schema() == pa.schema(
        {"bl": pa.int64()}
    )
    assert tbl.search([1, 2]).select(["a"]).output_schema() == pa.schema(
        {"a": pa.int64(), "_distance": pa.float32()}
    )
    assert tbl.search("blah").select(["a"]).output_schema() == pa.schema(
        {"a": pa.int64()}
    )
    assert tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
        {"text": pa.string()}
    )


@pytest.mark.asyncio
async def test_query_schema_async(tmp_path):
    db = await lancedb.connect_async(tmp_path)
    tbl = await db.create_table(
        "test",
        pa.table(
            {
                "a": [1, 2, 3],
                "text": ["a", "b", "c"],
                "vec": pa.array(
                    [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
                ),
            }
        ),
    )

    assert await tbl.query().output_schema() == pa.schema(
        {
            "a": pa.int64(),
            "text": pa.string(),
            "vec": pa.list_(pa.float32(), list_size=2),
        }
    )
    assert await tbl.query().select({"bl": "a * 2"}).output_schema() == pa.schema(
        {"bl": pa.int64()}
    )
    assert await tbl.vector_search([1, 2]).select(["a"]).output_schema() == pa.schema(
        {"a": pa.int64(), "_distance": pa.float32()}
    )
    assert await (await tbl.search("blah")).select(["a"]).output_schema() == pa.schema(
        {"a": pa.int64()}
    )
    assert await tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema(
        {"text": pa.string()}
    )
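The two tests above pin down the contract of the new output_schema(): it resolves the result schema, including computed columns like _distance and dynamic projections, without reading any data. That makes it useful for validating a projection assembled at runtime, as in this sketch (hypothetical helper; tbl as created in the tests above):

def build_projection(tbl, exprs):
    # exprs: mapping of output column name -> SQL expression, built at runtime
    query = tbl.search(None).select(exprs)
    # Resolves the schema up front; no rows are scanned or returned here.
    print(f"projection will produce: {query.output_schema()}")
    return query

query = build_projection(tbl, {"bl": "a * 2", "text": "text"})
result = query.to_arrow()  # only now does the query actually execute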

def test_query_timeout(tmp_path):
    # Use local directory instead of memory:// to add a bit of latency to
    # operations so a timeout of zero will trigger exceptions.

@@ -484,7 +484,7 @@ def test_jina_reranker(tmp_path, use_tantivy):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_voyageai_reranker(tmp_path, use_tantivy):
    pytest.importorskip("voyageai")
    reranker = VoyageAIReranker(model_name="rerank-2")
    reranker = VoyageAIReranker(model_name="rerank-2.5")
    table, schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(reranker, table, "single player experience", None, schema)

@@ -5,8 +5,8 @@ use std::sync::{Arc, Mutex};

use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
    permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
    split::{SplitSizes, SplitStrategy},
    permutation::builder::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
    permutation::split::{SplitSizes, SplitStrategy},
};
use pyo3::{
    exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
@@ -16,10 +16,7 @@ use pyo3_async_runtimes::tokio::future_into_py;

/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(
    table: Bound<'_, PyAny>,
    dest_table_name: String,
) -> PyResult<PyAsyncPermutationBuilder> {
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
    let table = table.getattr("_inner")?.downcast_into::<Table>()?;
    let inner_table = table.borrow().inner_ref()?.clone();
    let inner_builder = LancePermutationBuilder::new(inner_table);
@@ -27,14 +24,12 @@ pub fn async_permutation_builder(
    Ok(PyAsyncPermutationBuilder {
        state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
            builder: Some(inner_builder),
            dest_table_name,
        })),
    })
}

struct PyAsyncPermutationBuilderState {
    builder: Option<LancePermutationBuilder>,
    dest_table_name: String,
}

#[pyclass(name = "AsyncPermutationBuilder")]
@@ -167,10 +162,8 @@ impl PyAsyncPermutationBuilder {
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;

        let dest_table_name = std::mem::take(&mut state.dest_table_name);

        future_into_py(slf.py(), async move {
            let table = builder.build(&dest_table_name).await.infer_error()?;
            let table = builder.build().await.infer_error()?;
            Ok(Table::new(table))
        })
    }

@@ -9,6 +9,7 @@ use arrow::array::Array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use arrow::pyarrow::IntoPyArrow;
use arrow::pyarrow::ToPyArrow;
use lancedb::index::scalar::{
    BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
    Operator, PhraseQuery,
@@ -30,6 +31,7 @@ use pyo3::IntoPyObject;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3::Python;
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
use pyo3::{
    exceptions::{PyNotImplementedError, PyValueError},
@@ -445,6 +447,15 @@ impl Query {
        })
    }

    #[pyo3(signature = ())]
    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            let schema = inner.output_schema().await.infer_error()?;
            Python::with_gil(|py| schema.to_pyarrow(py))
        })
    }

    #[pyo3(signature = (max_batch_length=None, timeout=None))]
    pub fn execute(
        self_: PyRef<'_, Self>,
@@ -515,6 +526,15 @@ impl TakeQuery {
        self.inner = self.inner.clone().with_row_id();
    }

    #[pyo3(signature = ())]
    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            let schema = inner.output_schema().await.infer_error()?;
            Python::with_gil(|py| schema.to_pyarrow(py))
        })
    }

    #[pyo3(signature = (max_batch_length=None, timeout=None))]
    pub fn execute(
        self_: PyRef<'_, Self>,
@@ -601,6 +621,15 @@ impl FTSQuery {
        self.inner = self.inner.clone().postfilter();
    }

    #[pyo3(signature = ())]
    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            let schema = inner.output_schema().await.infer_error()?;
            Python::with_gil(|py| schema.to_pyarrow(py))
        })
    }

    #[pyo3(signature = (max_batch_length=None, timeout=None))]
    pub fn execute(
        self_: PyRef<'_, Self>,
@@ -771,6 +800,15 @@ impl VectorQuery {
        self.inner = self.inner.clone().bypass_vector_index()
    }

    #[pyo3(signature = ())]
    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            let schema = inner.output_schema().await.infer_error()?;
            Python::with_gil(|py| schema.to_pyarrow(py))
        })
    }

    #[pyo3(signature = (max_batch_length=None, timeout=None))]
    pub fn execute(
        self_: PyRef<'_, Self>,

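Note that the same nine-line binding is added to all four pyo3 wrappers (Query, TakeQuery, FTSQuery, VectorQuery), so every async query type exposes an awaitable output_schema that resolves to a pyarrow schema. Roughly, from the async Python API exercised earlier (table setup assumed):

async def inspect_schemas(tbl):
    # The same awaitable is exposed on plain, vector, FTS, and take queries.
    plain = await tbl.query().output_schema()
    vector = await tbl.vector_search([1, 2]).select(["a"]).output_schema()
    fts = await (await tbl.search("blah")).select(["a"]).output_schema()
    take = await tbl.take_offsets([0]).select(["text"]).output_schema()
    return plain, vector, fts, take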
@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.22.2"
version = "0.22.3-beta.3"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -16,6 +16,7 @@ arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
@@ -41,7 +42,9 @@ lance-table = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
lance-encoding = { workspace = true }
lance-arrow = { workspace = true }
lance-namespace = { workspace = true }
lance-namespace-impls = { workspace = true }
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }

@@ -1182,13 +1182,13 @@ mod tests {
    use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
    use crate::query::QueryBase;
    use crate::query::{ExecutableQuery, QueryExecutionOptions};
    use crate::test_connection::test_utils::new_test_connection;
    use crate::test_utils::connection::new_test_connection;
    use arrow::compute::concat_batches;
    use arrow_array::RecordBatchReader;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{stream, TryStreamExt};
    use lance::error::{ArrowResult, DataFusionResult};
    use lance_core::error::{ArrowResult, DataFusionResult};
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
    use tempfile::tempdir;

@@ -12,7 +12,7 @@ use arrow_array::{
use arrow_cast::{can_cast_types, cast};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use half::f16;
use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
use lance_arrow::{DataTypeExt, FixedSizeListArrayExt};
use log::warn;
use num_traits::cast::AsPrimitive;
@@ -189,7 +189,7 @@ mod tests {
    };
    use arrow_schema::Field;
    use half::f16;
    use lance::arrow::FixedSizeListArrayExt;
    use lance_arrow::FixedSizeListArrayExt;

    #[test]
    fn test_coerce_list_to_fixed_size_list() {

@@ -8,13 +8,13 @@ use std::sync::Arc;

use async_trait::async_trait;
use lance_namespace::{
    connect as connect_namespace,
    models::{
        CreateEmptyTableRequest, CreateNamespaceRequest, DescribeTableRequest,
        DropNamespaceRequest, DropTableRequest, ListNamespacesRequest, ListTablesRequest,
    },
    LanceNamespace,
};
use lance_namespace_impls::ConnectBuilder;

use crate::database::listing::ListingDatabase;
use crate::error::{Error, Result};
@@ -48,11 +48,16 @@ impl LanceNamespaceDatabase {
        read_consistency_interval: Option<std::time::Duration>,
        session: Option<Arc<lance::session::Session>>,
    ) -> Result<Self> {
        let namespace = connect_namespace(ns_impl, ns_properties.clone())
            .await
            .map_err(|e| Error::InvalidInput {
                message: format!("Failed to connect to namespace: {:?}", e),
            })?;
        let mut builder = ConnectBuilder::new(ns_impl);
        for (key, value) in ns_properties.clone() {
            builder = builder.property(key, value);
        }
        if let Some(ref sess) = session {
            builder = builder.session(sess.clone());
        }
        let namespace = builder.connect().await.map_err(|e| Error::InvalidInput {
            message: format!("Failed to connect to namespace: {:?}", e),
        })?;

        Ok(Self {
            namespace,

@@ -2,6 +2,3 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

pub mod permutation;
pub mod shuffle;
pub mod split;
pub mod util;

@@ -7,288 +7,12 @@
//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
//! the underlying data and can be very lightweight.
//!
//! Building a permutation table should be fairly quick and memory efficient, even for billions or
//! trillions of rows.
//! Building a permutation table should be fairly quick (it is an O(N) operation where N is
//! the number of rows in the base table) and memory efficient, even for billions or trillions
//! of rows.

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datafusion::exec::SessionContextExt;

use crate::{
    arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
    dataloader::{
        shuffle::{Shuffler, ShufflerConfig},
        split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
        util::{rename_column, TemporaryDirectory},
    },
    query::{ExecutableQuery, QueryBase},
    Connection, Error, Result, Table,
};

/// Configuration for creating a permutation table
#[derive(Debug, Default)]
pub struct PermutationConfig {
    /// Splitting configuration
    pub split_strategy: SplitStrategy,
    /// Shuffle strategy
    pub shuffle_strategy: ShuffleStrategy,
    /// Optional filter to apply to the base table
    pub filter: Option<String>,
    /// Directory to use for temporary files
    pub temp_dir: TemporaryDirectory,
}

/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
pub enum ShuffleStrategy {
    /// The data is randomly shuffled
    ///
    /// A seed can be provided to make the shuffle deterministic.
    ///
    /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
    /// This decreases the overall randomization but can improve I/O performance when reading from
    /// cloud storage.
    ///
    /// For example, a clump size of 16 means we will shuffle blocks of 16 contiguous rows. This
    /// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence
    /// the performance of the model. Note: shuffling within clumps can still be done at read time but
    /// this will only provide a local shuffle and not a global shuffle.
    Random {
        seed: Option<u64>,
        clump_size: Option<u64>,
    },
    /// The data is not shuffled
    ///
    /// This is useful for debugging and testing.
    None,
}

impl Default for ShuffleStrategy {
    fn default() -> Self {
        Self::None
    }
}

/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
/// can be used to create a
pub struct PermutationBuilder {
    config: PermutationConfig,
    base_table: Table,
}

impl PermutationBuilder {
    pub fn new(base_table: Table) -> Self {
        Self {
            config: PermutationConfig::default(),
            base_table,
        }
    }

    /// Configures the strategy for assigning rows to splits.
    ///
    /// For example, it is common to create a test/train split of the data. Splits can also be used
    /// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
    /// create a single split with 10% of the data.
    ///
    /// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
    /// multiple processes and multiple nodes.
    ///
    /// The default is a single split that contains all rows.
    pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
        self.config.split_strategy = split_strategy;
        self
    }

    /// Configures the strategy for shuffling the data.
    ///
    /// The default is to shuffle the data randomly at row-level granularity (no shard size) and
    /// with a random seed.
    pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
        self.config.shuffle_strategy = shuffle_strategy;
        self
    }

    /// Configures a filter to apply to the base table.
    ///
    /// Only rows matching the filter will be included in the permutation.
    pub fn with_filter(mut self, filter: String) -> Self {
        self.config.filter = Some(filter);
        self
    }

    /// Configures the directory to use for temporary files.
    ///
    /// The default is to use the operating system's default temporary directory.
    pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
        self.config.temp_dir = temp_dir;
        self
    }

    async fn sort_by_split_id(
        &self,
        data: SendableRecordBatchStream,
    ) -> Result<SendableRecordBatchStream> {
        let ctx = SessionContext::new_with_config_rt(
            SessionConfig::default(),
            RuntimeEnvBuilder::new()
                .with_memory_limit(100 * 1024 * 1024, 1.0)
                .with_disk_manager_builder(
                    DiskManagerBuilder::default()
                        .with_mode(self.config.temp_dir.to_disk_manager_mode()),
                )
                .build_arc()
                .unwrap(),
        );
        let df = ctx
            .read_one_shot(data.into_df_stream())
            .map_err(|e| Error::Other {
                message: format!("Failed to setup sort by split id: {}", e),
                source: Some(e.into()),
            })?;
        let df_stream = df
            .sort_by(vec![col(SPLIT_ID_COLUMN)])
            .map_err(|e| Error::Other {
                message: format!("Failed to plan sort by split id: {}", e),
                source: Some(e.into()),
            })?
            .execute_stream()
            .await
            .map_err(|e| Error::Other {
                message: format!("Failed to sort by split id: {}", e),
                source: Some(e.into()),
            })?;

        let schema = df_stream.schema();
        let stream = df_stream.map_err(|e| Error::Other {
            message: format!("Failed to execute sort by split id: {}", e),
            source: Some(e.into()),
        });
        Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
    }

    /// Builds the permutation table and stores it in the given database.
    pub async fn build(self, dest_table_name: &str) -> Result<Table> {
        // First pass, apply filter and load row ids
        let mut rows = self.base_table.query().with_row_id();

        if let Some(filter) = &self.config.filter {
            rows = rows.only_if(filter);
        }

        let splitter = Splitter::new(
            self.config.temp_dir.clone(),
            self.config.split_strategy.clone(),
        );

        let mut needs_sort = !splitter.orders_by_split_id();

        // Might need to load additional columns to calculate splits (e.g. hash columns or calculated
        // split id)
        rows = splitter.project(rows);

        let num_rows = self
            .base_table
            .count_rows(self.config.filter.clone())
            .await? as u64;

        // Apply splits
        let rows = rows.execute().await?;
        let split_data = splitter.apply(rows, num_rows).await?;

        // Shuffle data if requested
        let shuffled = match self.config.shuffle_strategy {
            ShuffleStrategy::None => split_data,
            ShuffleStrategy::Random { seed, clump_size } => {
                let shuffler = Shuffler::new(ShufflerConfig {
                    seed,
                    clump_size,
                    temp_dir: self.config.temp_dir.clone(),
                    max_rows_per_file: 10 * 1024 * 1024,
                });
                shuffler.shuffle(split_data, num_rows).await?
            }
        };

        // We want the final permutation to be sorted by the split id. If we shuffled or if
        // the split was not assigned sequentially then we need to sort the data.
        needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);

        let sorted = if needs_sort {
            self.sort_by_split_id(shuffled).await?
        } else {
            shuffled
        };

        // Rename _rowid to row_id
        let renamed = rename_column(sorted, "_rowid", "row_id")?;

        // Create permutation table
        let conn = Connection::new(
            self.base_table.database().clone(),
            self.base_table.embedding_registry().clone(),
        );
        conn.create_table_streaming(dest_table_name, renamed)
            .execute()
            .await
    }
}

#[cfg(test)]
mod tests {
    use arrow::datatypes::Int32Type;
    use lance_datagen::{BatchCount, RowCount};

    use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes};

    use super::*;

    #[tokio::test]
    async fn test_permutation_builder() {
        let temp_dir = tempfile::tempdir().unwrap();

        let db = connect(temp_dir.path().to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let initial_data = lance_datagen::gen_batch()
            .col("some_value", lance_datagen::array::step::<Int32Type>())
            .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
        let data_table = db
            .create_table_streaming("mytbl", initial_data)
            .execute()
            .await
            .unwrap();

        let permutation_table = PermutationBuilder::new(data_table)
            .with_filter("some_value > 57".to_string())
            .with_split_strategy(SplitStrategy::Random {
                seed: Some(42),
                sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
            })
            .build("permutation")
            .await
            .unwrap();

        // Potentially brittle seed-dependent values below
        assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
        assert_eq!(
            permutation_table
                .count_rows(Some("split_id = 0".to_string()))
                .await
                .unwrap(),
            47
        );
        assert_eq!(
            permutation_table
                .count_rows(Some("split_id = 1".to_string()))
                .await
                .unwrap(),
            283
        );
    }
}
pub mod builder;
pub mod reader;
pub mod shuffle;
pub mod split;
pub mod util;

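The clump-size note in the ShuffleStrategy docs above is the main performance knob here: shuffling fixed blocks of contiguous rows trades global randomness for fewer, larger reads. From the Python side the tradeoff looks like this sketch (fixtures mirror the permutation tests earlier; the 16x figure is the doc comment's own arithmetic for clump_size=16):

def test_clumped_shuffle(mem_db):
    tbl = mem_db.create_table("test_table", pa.table({"id": range(10_000)}))
    # clump_size=16 shuffles 16-row blocks: ~16x fewer IOPS when the permutation
    # is later read from cloud storage, but rows within a block stay adjacent
    # (a local rather than global shuffle).
    perm = permutation_builder(tbl).shuffle(seed=42, clump_size=16).execute()
    assert perm.count_rows() == 10_000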
rust/lancedb/src/dataloader/permutation/builder.rs (new file, 326 lines)
@@ -0,0 +1,326 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::sync::Arc;

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_core::ROW_ID;
use lance_datafusion::exec::SessionContextExt;

use crate::{
    arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
    connect,
    database::{CreateTableData, CreateTableRequest, Database},
    dataloader::permutation::{
        shuffle::{Shuffler, ShufflerConfig},
        split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
        util::{rename_column, TemporaryDirectory},
    },
    query::{ExecutableQuery, QueryBase},
    Error, Result, Table,
};

pub const SRC_ROW_ID_COL: &str = "row_id";

/// Where to store the permutation table
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
    /// The permutation table is a temporary table in memory
    #[default]
    Temporary,
    /// The permutation table is a permanent table in a database
    Permanent(Arc<dyn Database>, String),
}

/// Configuration for creating a permutation table
#[derive(Debug, Default)]
pub struct PermutationConfig {
    /// Splitting configuration
    split_strategy: SplitStrategy,
    /// Shuffle strategy
    shuffle_strategy: ShuffleStrategy,
    /// Optional filter to apply to the base table
    filter: Option<String>,
    /// Directory to use for temporary files
    temp_dir: TemporaryDirectory,
    /// Destination
    destination: PermutationDestination,
}

/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
pub enum ShuffleStrategy {
    /// The data is randomly shuffled
    ///
    /// A seed can be provided to make the shuffle deterministic.
    ///
    /// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
    /// This decreases the overall randomization but can improve I/O performance when reading from
    /// cloud storage.
    ///
    /// For example, a clump size of 16 means we will shuffle blocks of 16 contiguous rows. This
    /// will mean 16x fewer IOPS but these 16 rows will always be close together and this can influence
    /// the performance of the model. Note: shuffling within clumps can still be done at read time but
    /// this will only provide a local shuffle and not a global shuffle.
    Random {
        seed: Option<u64>,
        clump_size: Option<u64>,
    },
    /// The data is not shuffled
    ///
    /// This is useful for debugging and testing.
    None,
}

impl Default for ShuffleStrategy {
    fn default() -> Self {
        Self::None
    }
}

/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This
/// can be used to create a permutation reader that reads rows in the order defined by the permutation.
///
/// The permutation table is neither a view nor a materialized copy of the underlying data. It is a
/// separate, very lightweight table that stores just the row id and split id.
pub struct PermutationBuilder {
    config: PermutationConfig,
    base_table: Table,
}

impl PermutationBuilder {
    pub fn new(base_table: Table) -> Self {
        Self {
            config: PermutationConfig::default(),
            base_table,
        }
    }

    /// Configures the strategy for assigning rows to splits.
    ///
    /// For example, it is common to create a test/train split of the data. Splits can also be used
    /// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
    /// create a single split with 10% of the data.
    ///
    /// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
    /// multiple processes and multiple nodes.
    ///
    /// The default is a single split that contains all rows.
    pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
        self.config.split_strategy = split_strategy;
        self
    }

    /// Configures the strategy for shuffling the data.
    ///
    /// The default is to shuffle the data randomly at row-level granularity (no clump size) and
    /// with a random seed.
    pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
        self.config.shuffle_strategy = shuffle_strategy;
        self
    }

    /// Configures a filter to apply to the base table.
    ///
    /// Only rows matching the filter will be included in the permutation.
    pub fn with_filter(mut self, filter: String) -> Self {
        self.config.filter = Some(filter);
        self
    }

    /// Configures the directory to use for temporary files.
    ///
    /// The default is to use the operating system's default temporary directory.
    pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
        self.config.temp_dir = temp_dir;
        self
    }

    /// Stores the permutation as a table in a database
    ///
    /// By default, the permutation is stored in memory. If this method is called then
    /// the permutation will be stored as a table in the given database.
    pub fn persist(mut self, database: Arc<dyn Database>, table_name: String) -> Self {
        self.config.destination = PermutationDestination::Permanent(database, table_name);
        self
    }

    async fn sort_by_split_id(
        &self,
        data: SendableRecordBatchStream,
    ) -> Result<SendableRecordBatchStream> {
        let ctx = SessionContext::new_with_config_rt(
            SessionConfig::default(),
            RuntimeEnvBuilder::new()
                .with_memory_limit(100 * 1024 * 1024, 1.0)
                .with_disk_manager_builder(
                    DiskManagerBuilder::default()
                        .with_mode(self.config.temp_dir.to_disk_manager_mode()),
                )
                .build_arc()
                .unwrap(),
        );
        let df = ctx
            .read_one_shot(data.into_df_stream())
            .map_err(|e| Error::Other {
                message: format!("Failed to setup sort by split id: {}", e),
                source: Some(e.into()),
            })?;
        let df_stream = df
            .sort_by(vec![col(SPLIT_ID_COLUMN)])
            .map_err(|e| Error::Other {
                message: format!("Failed to plan sort by split id: {}", e),
                source: Some(e.into()),
            })?
            .execute_stream()
            .await
            .map_err(|e| Error::Other {
                message: format!("Failed to sort by split id: {}", e),
                source: Some(e.into()),
            })?;

        let schema = df_stream.schema();
        let stream = df_stream.map_err(|e| Error::Other {
            message: format!("Failed to execute sort by split id: {}", e),
            source: Some(e.into()),
        });
        Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
    }

    /// Builds the permutation table and stores it at the configured destination.
    pub async fn build(self) -> Result<Table> {
        // First pass, apply filter and load row ids
        let mut rows = self.base_table.query().with_row_id();

        if let Some(filter) = &self.config.filter {
            rows = rows.only_if(filter);
        }

        let splitter = Splitter::new(
            self.config.temp_dir.clone(),
            self.config.split_strategy.clone(),
        );

        let mut needs_sort = !splitter.orders_by_split_id();

        // Might need to load additional columns to calculate splits (e.g. hash columns or calculated
        // split id)
        rows = splitter.project(rows);

        let num_rows = self
            .base_table
            .count_rows(self.config.filter.clone())
            .await? as u64;

        // Apply splits
        let rows = rows.execute().await?;
        let split_data = splitter.apply(rows, num_rows).await?;

        // Shuffle data if requested
        let shuffled = match self.config.shuffle_strategy {
            ShuffleStrategy::None => split_data,
            ShuffleStrategy::Random { seed, clump_size } => {
                let shuffler = Shuffler::new(ShufflerConfig {
                    seed,
                    clump_size,
                    temp_dir: self.config.temp_dir.clone(),
                    max_rows_per_file: 10 * 1024 * 1024,
                });
                shuffler.shuffle(split_data, num_rows).await?
            }
        };

        // We want the final permutation to be sorted by the split id. If we shuffled or if
        // the split was not assigned sequentially then we need to sort the data.
        needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);

        let sorted = if needs_sort {
            self.sort_by_split_id(shuffled).await?
        } else {
            shuffled
        };

        // Rename _rowid to row_id
        let renamed = rename_column(sorted, ROW_ID, SRC_ROW_ID_COL)?;

        let (name, database) = match &self.config.destination {
            PermutationDestination::Permanent(database, table_name) => {
                (table_name.as_str(), database.clone())
            }
            PermutationDestination::Temporary => {
                let conn = connect("memory:///").execute().await?;
                ("permutation", conn.database().clone())
            }
        };

        let create_table_request =
            CreateTableRequest::new(name.to_string(), CreateTableData::StreamingData(renamed));

        let table = database.create_table(create_table_request).await?;
        Ok(Table::new(table, database))
    }
}

#[cfg(test)]
mod tests {
    use arrow::datatypes::Int32Type;
    use lance_datagen::{BatchCount, RowCount};

    use crate::{arrow::LanceDbDatagenExt, connect, dataloader::permutation::split::SplitSizes};

    use super::*;

    #[tokio::test]
    async fn test_permutation_builder() {
        let temp_dir = tempfile::tempdir().unwrap();

        let db = connect(temp_dir.path().to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let initial_data = lance_datagen::gen_batch()
            .col("some_value", lance_datagen::array::step::<Int32Type>())
            .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
        let data_table = db
            .create_table_streaming("mytbl", initial_data)
            .execute()
            .await
            .unwrap();

        let permutation_table = PermutationBuilder::new(data_table.clone())
            .with_filter("some_value > 57".to_string())
            .with_split_strategy(SplitStrategy::Random {
                seed: Some(42),
                sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
            })
            .build()
            .await
            .unwrap();

        println!("permutation_table: {:?}", permutation_table);

        // Potentially brittle seed-dependent values below
        assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
        assert_eq!(
            permutation_table
                .count_rows(Some("split_id = 0".to_string()))
                .await
                .unwrap(),
            47
        );
        assert_eq!(
            permutation_table
                .count_rows(Some("split_id = 1".to_string()))
                .await
                .unwrap(),
            283
        );
    }
}
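Compared with the removed module, the builder above no longer takes a destination name at build(); the permutation lands in an in-memory table unless persist points it at a database. Since the output is just row_id plus split_id, split sizes can be checked with ordinary filters, as in this sketch (Python fixtures as in the earlier tests; count_rows with a filter matches the Rust test above):

def test_split_sizes_add_up(mem_db):
    tbl = mem_db.create_table("test_table", pa.table({"id": range(100)}))
    perm = permutation_builder(tbl).split_random(ratios=[0.3, 0.7], seed=1).execute()
    # The permutation holds one row per source row: just row_id and split_id.
    split0 = perm.count_rows("split_id = 0")
    split1 = perm.count_rows("split_id = 1")
    assert split0 + split1 == perm.count_rows() == 100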
rust/lancedb/src/dataloader/permutation/reader.rs (new file, 384 lines)
@@ -0,0 +1,384 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! Row ID-based views for LanceDB tables
//!
//! This module provides functionality for creating views that are based on specific row IDs.
//! The `PermutationReader` allows you to create a virtual table that contains only
//! the rows from a source table that correspond to row IDs stored in a separate table.

use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use crate::dataloader::permutation::builder::SRC_ROW_ID_COL;
use crate::dataloader::permutation::split::SPLIT_ID_COLUMN;
use crate::error::Error;
use crate::query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select};
use crate::table::{AnyQuery, BaseTable};
use crate::Result;
use arrow::array::AsArray;
use arrow::datatypes::UInt64Type;
use arrow_array::{RecordBatch, UInt64Array};
use futures::{StreamExt, TryStreamExt};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance_arrow::RecordBatchExt;
use lance_core::error::LanceOptionExt;
use lance_core::ROW_ID;
use std::collections::HashMap;
use std::sync::Arc;

/// Reads a permutation of a source table based on row IDs stored in a separate table
pub struct PermutationReader {
    base_table: Arc<dyn BaseTable>,
    permutation_table: Arc<dyn BaseTable>,
}

impl std::fmt::Debug for PermutationReader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "PermutationReader(base={}, permutation={})",
            self.base_table.name(),
            self.permutation_table.name(),
        )
    }
}

impl PermutationReader {
    /// Create a new PermutationReader
    pub async fn try_new(
        base_table: Arc<dyn BaseTable>,
        permutation_table: Arc<dyn BaseTable>,
    ) -> Result<Self> {
        let schema = permutation_table.schema().await?;
        if schema.column_with_name(SRC_ROW_ID_COL).is_none() {
            return Err(Error::InvalidInput {
                message: "Permutation table must contain a column named row_id".to_string(),
            });
        }
        if schema.column_with_name(SPLIT_ID_COLUMN).is_none() {
            return Err(Error::InvalidInput {
                message: "Permutation table must contain a column named split_id".to_string(),
            });
        }
        Ok(Self {
            base_table,
            permutation_table,
        })
    }

    fn is_sorted_already<'a, T: Iterator<Item = &'a u64>>(iter: T) -> bool {
        for (expected, idx) in iter.enumerate() {
            if *idx != expected as u64 {
                return false;
            }
        }
        true
    }

    async fn load_batch(
        base_table: &Arc<dyn BaseTable>,
        row_ids: RecordBatch,
        selection: Select,
        has_row_id: bool,
    ) -> Result<RecordBatch> {
        let num_rows = row_ids.num_rows();
        let row_ids = row_ids
            .column(0)
            .as_primitive_opt::<UInt64Type>()
            .expect_ok()?
            .values();

        let filter = format!(
            "_rowid in ({})",
            row_ids
                .iter()
                .map(|o| o.to_string())
                .collect::<Vec<_>>()
                .join(",")
        );

        let base_query = QueryRequest {
            filter: Some(QueryFilter::Sql(filter)),
            select: selection,
            with_row_id: true,
            ..Default::default()
        };

        let mut data = base_table
            .query(
                &AnyQuery::Query(base_query),
                QueryExecutionOptions {
                    max_batch_length: num_rows as u32,
                    ..Default::default()
                },
            )
            .await?;

        let Some(batch) = data.try_next().await? else {
            return Err(Error::InvalidInput {
                message: "Base table returned no batches".to_string(),
            });
        };
        if data.try_next().await?.is_some() {
            return Err(Error::InvalidInput {
                message: "Base table returned more than one batch".to_string(),
            });
        }

        if batch.num_rows() != num_rows {
            return Err(Error::InvalidInput {
                message: "Base table returned different number of rows than the number of row IDs"
                    .to_string(),
            });
        }

        // There is no guarantee the result order will match the order provided
        // so may need to restore order
        let actual_row_ids = batch
            .column_by_name(ROW_ID)
            .expect_ok()?
            .as_primitive_opt::<UInt64Type>()
            .expect_ok()?
            .values();

        // Map from row id to order in batch, used to restore original ordering
        let ordering = actual_row_ids
            .iter()
            .copied()
            .enumerate()
            .map(|(i, o)| (o, i as u64))
            .collect::<HashMap<_, _>>();

        let desired_idx_order = row_ids
            .iter()
            .map(|o| ordering.get(o).copied().expect_ok().map_err(Error::from))
            .collect::<Result<Vec<_>>>()?;

        let ordered_batch = if Self::is_sorted_already(desired_idx_order.iter()) {
            // Fast path if already sorted, important as data may be large and
            // re-ordering could be expensive
            batch
        } else {
            let desired_idx_order = UInt64Array::from(desired_idx_order);

            arrow_select::take::take_record_batch(&batch, &desired_idx_order)?
        };

        if has_row_id {
            Ok(ordered_batch)
        } else {
            // The user didn't ask for row id, we needed it for ordering the data, but now we drop it
            Ok(ordered_batch.drop_column(ROW_ID)?)
        }
    }

    async fn row_ids_to_batches(
        base_table: Arc<dyn BaseTable>,
        row_ids: DatasetRecordBatchStream,
        selection: Select,
    ) -> Result<SendableRecordBatchStream> {
        let has_row_id = Self::has_row_id(&selection)?;
        let mut stream = row_ids
            .map_err(Error::from)
            .try_filter_map(move |batch| {
                let selection = selection.clone();
                let base_table = base_table.clone();
                async move {
                    Self::load_batch(&base_table, batch, selection, has_row_id)
                        .await
                        .map(Some)
                }
            })
            .boxed();

        // Need to read out first batch to get schema
        let Some(first_batch) = stream.try_next().await? else {
            return Err(Error::InvalidInput {
                message: "Permutation was empty".to_string(),
            });
        };
        let schema = first_batch.schema();

        let stream = futures::stream::once(std::future::ready(Ok(first_batch))).chain(stream);

        Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
    }

    fn has_row_id(selection: &Select) -> Result<bool> {
        match selection {
            Select::All => {
                // _rowid is a system column and is not included in Select::All
                Ok(false)
            }
            Select::Columns(columns) => Ok(columns.contains(&ROW_ID.to_string())),
            Select::Dynamic(columns) => {
                for column in columns {
                    if column.0 == ROW_ID {
                        if column.1 == ROW_ID {
                            return Ok(true);
                        } else {
                            return Err(Error::InvalidInput {
                                message: format!(
                                    "Dynamic column {} cannot be used to select _rowid",
                                    column.1
                                ),
                            });
                        }
                    }
                }
                Ok(false)
            }
        }
    }

    pub async fn read_split(
        &self,
        split: u64,
        selection: Select,
        execution_options: QueryExecutionOptions,
    ) -> Result<SendableRecordBatchStream> {
        let row_ids = self
            .permutation_table
            .query(
                &AnyQuery::Query(QueryRequest {
                    select: Select::Columns(vec![SRC_ROW_ID_COL.to_string()]),
                    filter: Some(QueryFilter::Sql(format!("{} = {}", SPLIT_ID_COLUMN, split))),
                    ..Default::default()
                }),
                execution_options,
            )
            .await?;

        Self::row_ids_to_batches(self.base_table.clone(), row_ids, selection).await
    }
}

#[cfg(test)]
mod tests {
    use arrow::datatypes::Int32Type;
    use arrow_array::{ArrowPrimitiveType, RecordBatch, UInt64Array};
    use arrow_schema::{DataType, Field, Schema};
    use lance_datagen::{BatchCount, RowCount};
    use rand::seq::SliceRandom;

    use crate::{
        arrow::SendableRecordBatchStream,
        query::{ExecutableQuery, QueryBase},
        test_utils::datagen::{virtual_table, LanceDbDatagenExt},
        Table,
    };

    use super::*;

    async fn collect_from_stream<T: ArrowPrimitiveType>(
        mut stream: SendableRecordBatchStream,
        column: &str,
    ) -> Vec<T::Native> {
        let mut row_ids = Vec::new();
        while let Some(batch) = stream.try_next().await.unwrap() {
            let col_idx = batch.schema().index_of(column).unwrap();
            row_ids.extend(batch.column(col_idx).as_primitive::<T>().values().to_vec());
        }
        row_ids
    }

    async fn collect_column<T: ArrowPrimitiveType>(table: &Table, column: &str) -> Vec<T::Native> {
        collect_from_stream::<T>(
            table
                .query()
                .select(Select::Columns(vec![column.to_string()]))
                .execute()
                .await
                .unwrap(),
            column,
        )
        .await
    }

    #[tokio::test]
    async fn test_permutation_reader() {
        let base_table = lance_datagen::gen_batch()
            .col("idx", lance_datagen::array::step::<Int32Type>())
            .col("other_col", lance_datagen::array::step::<UInt64Type>())
            .into_mem_table("tbl", RowCount::from(9), BatchCount::from(1))
            .await;

        let mut row_ids = collect_column::<UInt64Type>(&base_table, "_rowid").await;
        row_ids.shuffle(&mut rand::rng());
        // Put the last two rows in split 1
        let split_ids = UInt64Array::from_iter_values(
            std::iter::repeat_n(0, row_ids.len() - 2).chain(std::iter::repeat_n(1, 2)),
        );
        let permutation_batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("row_id", DataType::UInt64, false),
                Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
            ])),
            vec![
                Arc::new(UInt64Array::from(row_ids.clone())),
                Arc::new(split_ids),
            ],
        )
        .unwrap();
        let row_ids_table = virtual_table("row_ids", &permutation_batch).await;

        let reader = PermutationReader::try_new(
            base_table.base_table().clone(),
            row_ids_table.base_table().clone(),
        )
        .await
        .unwrap();

        // Read split 0
        let mut stream = reader
            .read_split(
                0,
                Select::All,
                QueryExecutionOptions {
                    max_batch_length: 3,
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        assert_eq!(stream.schema(), base_table.schema().await.unwrap());

        let check_batch = async |stream: &mut SendableRecordBatchStream,
                                 expected_values: &[u64]| {
            let batch = stream.try_next().await.unwrap().unwrap();
            assert_eq!(batch.num_rows(), expected_values.len());
            assert_eq!(
                batch.column(0).as_primitive::<Int32Type>().values(),
                &expected_values
                    .iter()
                    .map(|o| *o as i32)
                    .collect::<Vec<_>>()
            );
            assert_eq!(
                batch.column(1).as_primitive::<UInt64Type>().values(),
                &expected_values
            );
        };

        check_batch(&mut stream, &row_ids[0..3]).await;
        check_batch(&mut stream, &row_ids[3..6]).await;
        check_batch(&mut stream, &row_ids[6..7]).await;
        assert!(stream.try_next().await.unwrap().is_none());

        // Read split 1
        let mut stream = reader
            .read_split(
                1,
                Select::All,
                QueryExecutionOptions {
                    max_batch_length: 3,
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        check_batch(&mut stream, &row_ids[7..9]).await;
        assert!(stream.try_next().await.unwrap().is_none());
    }
}
@@ -22,7 +22,7 @@ use rand::{seq::SliceRandom, Rng, RngCore};

use crate::{
    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
    dataloader::util::{non_crypto_rng, TemporaryDirectory},
    dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
    Error, Result,
};

@@ -13,13 +13,13 @@ use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance::arrow::SchemaExt;
use lance_arrow::SchemaExt;

use crate::{
    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
    dataloader::{
        shuffle::{Shuffler, ShufflerConfig},
        util::TemporaryDirectory,
        permutation::shuffle::{Shuffler, ShufflerConfig},
        permutation::util::TemporaryDirectory,
    },
    query::{Query, QueryBase, Select},
    Error, Result,
@@ -10,7 +10,7 @@ pub mod sentence_transformers;
#[cfg(feature = "bedrock")]
pub mod bedrock;

use lance::arrow::RecordBatchExt;
use lance_arrow::RecordBatchExt;
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},

|
||||
pub mod remote;
|
||||
pub mod rerankers;
|
||||
pub mod table;
|
||||
pub mod test_connection;
|
||||
#[cfg(test)]
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
@@ -6,15 +6,13 @@ use std::{future::Future, time::Duration};

use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use arrow_schema::{DataType, SchemaRef};
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use futures::{stream, try_join, FutureExt, TryFutureExt, TryStreamExt};
use half::f16;
use lance::{
    arrow::RecordBatchExt,
    dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
};
use lance::dataset::{scanner::DatasetRecordBatchStream, ROW_ID};
use lance_arrow::RecordBatchExt;
use lance_datafusion::exec::execute_plan;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::scalar::FullTextSearchQuery;
@@ -582,16 +580,40 @@ pub trait ExecutableQuery {
        options: QueryExecutionOptions,
    ) -> impl Future<Output = Result<SendableRecordBatchStream>> + Send;

    /// Explain the plan for a query
    ///
    /// This will create a string representation of the plan that will be used to
    /// execute the query. This will not execute the query.
    ///
    /// This function can be used to get an understanding of what work will be done by the query
    /// and is useful for debugging query performance.
    fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;

    /// Execute the query and display the runtime metrics
    ///
    /// This shows the same plan as [`ExecutableQuery::explain_plan`] but includes runtime metrics.
    ///
    /// This function will actually execute the query in order to get the runtime metrics.
    fn analyze_plan(&self) -> impl Future<Output = Result<String>> + Send {
        self.analyze_plan_with_options(QueryExecutionOptions::default())
    }

    /// Execute the query and display the runtime metrics
    ///
    /// This is the same as [`ExecutableQuery::analyze_plan`] but allows for specifying the execution options.
    fn analyze_plan_with_options(
        &self,
        options: QueryExecutionOptions,
    ) -> impl Future<Output = Result<String>> + Send;

    /// Return the output schema for data returned by the query without actually executing the query
    ///
    /// This can be useful when the selection for a query is built dynamically as it is not always
    /// obvious what the output schema will be.
    fn output_schema(&self) -> impl Future<Output = Result<SchemaRef>> + Send {
        self.create_plan(QueryExecutionOptions::default())
            .and_then(|plan| std::future::ready(Ok(plan.schema())))
    }
}

/// A query filter that can be applied to a query

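The new plan-inspection surface is easiest to see from the caller's side. Below is a minimal sketch, not taken from the diff: it assumes `table` is an already-opened `lancedb::Table` and that `QueryBase` is in scope for `limit`.

use lancedb::query::{ExecutableQuery, QueryBase, QueryExecutionOptions};

async fn inspect_query(table: &lancedb::Table) -> lancedb::error::Result<()> {
    let query = table.query().limit(10);

    // New in this change: resolve the output schema without executing.
    let schema = query.output_schema().await?;
    println!("columns: {:?}", schema.fields());

    // Static plan description; nothing is executed.
    println!("{}", query.explain_plan(true).await?);

    // Executes the query and reports runtime metrics for the same plan.
    println!("{}", query.analyze_plan().await?);

    // Same metrics, but with explicit execution options.
    println!(
        "{}",
        query
            .analyze_plan_with_options(QueryExecutionOptions::default())
            .await?
    );
    Ok(())
}

Note that `output_schema` simply builds the physical plan and reads its schema, so it shares whatever cost plan creation has, but never scans data.
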
@@ -1505,6 +1527,16 @@ mod tests {
            .query()
            .limit(10)
            .select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")]));

        let schema = query.output_schema().await.unwrap();
        assert_eq!(
            schema,
            Arc::new(ArrowSchema::new(vec![
                ArrowField::new("id2", DataType::Int32, true),
                ArrowField::new("id", DataType::Int32, true),
            ]))
        );

        let result = query.execute().await;
        let mut batches = result
            .expect("should have result")

@@ -1427,6 +1427,10 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
        "NOT_SUPPORTED"
    }

    async fn storage_options(&self) -> Option<HashMap<String, String>> {
        None
    }

    async fn stats(&self) -> Result<TableStatistics> {
        let request = self
            .client

@@ -511,6 +511,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
    /// Get the namespace of the table.
    fn namespace(&self) -> &[String];
    /// Get the id of the table
    ///
    /// This is the namespace of the table concatenated with the name
    /// separated by a dot (".")
    fn id(&self) -> &str;
    /// Get the arrow [Schema] of the table.
    async fn schema(&self) -> Result<SchemaRef>;
@@ -598,6 +601,8 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
    async fn table_definition(&self) -> Result<TableDefinition>;
    /// Get the table URI
    fn dataset_uri(&self) -> &str;
    /// Get the storage options used when opening this table, if any.
    async fn storage_options(&self) -> Option<HashMap<String, String>>;
    /// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
    /// are not fully indexed within the timeout.
    async fn wait_for_index(
@@ -1290,6 +1295,13 @@ impl Table {
        self.inner.dataset_uri()
    }

    /// Get the storage options used when opening this table, if any.
    ///
    /// Warning: This is an internal API and the return value is subject to change.
    pub async fn storage_options(&self) -> Option<HashMap<String, String>> {
        self.inner.storage_options().await
    }

    /// Get statistics about an index.
    /// Returns None if the index does not exist.
    pub async fn index_stats(
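For callers, the accessor is a plain async getter. A minimal usage sketch follows; the helper function is hypothetical and assumes `table` is an open `lancedb::Table`.

async fn print_storage_options(table: &lancedb::Table) {
    // Returns the options captured when the table was opened; remote tables
    // return None, as does a native table whose dataset cannot be loaded.
    match table.storage_options().await {
        Some(opts) => {
            for (key, value) in &opts {
                println!("{key}={value}");
            }
        }
        None => println!("no storage options recorded"),
    }
}
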
@@ -2614,6 +2626,14 @@ impl BaseTable for NativeTable {
        self.uri.as_str()
    }

    async fn storage_options(&self) -> Option<HashMap<String, String>> {
        self.dataset
            .get()
            .await
            .ok()
            .and_then(|dataset| dataset.storage_options().cloned())
    }

    async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
        let stats = match self
            .dataset

@@ -2623,7 +2643,7 @@ impl BaseTable for NativeTable {
        .await
    {
        Ok(stats) => stats,
        Err(lance::error::Error::IndexNotFound { .. }) => return Ok(None),
        Err(lance_core::Error::IndexNotFound { .. }) => return Ok(None),
        Err(e) => return Err(Error::from(e)),
    };

@@ -1,126 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! Functions for testing connections.

#[cfg(test)]
pub mod test_utils {
    use regex::Regex;
    use std::env;
    use std::io::{BufRead, BufReader};
    use std::process::{Child, ChildStdout, Command, Stdio};

    use crate::{connect, Connection};
    use anyhow::{bail, Result};
    use tempfile::{tempdir, TempDir};

    pub struct TestConnection {
        pub uri: String,
        pub connection: Connection,
        _temp_dir: Option<TempDir>,
        _process: Option<TestProcess>,
    }

    struct TestProcess {
        child: Child,
    }

    impl Drop for TestProcess {
        #[allow(unused_must_use)]
        fn drop(&mut self) {
            self.child.kill();
        }
    }

    pub async fn new_test_connection() -> Result<TestConnection> {
        match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
            Ok(script_path) => new_remote_connection(&script_path).await,
            Err(_e) => new_local_connection().await,
        }
    }

    async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
        let temp_dir = tempdir()?;
        let data_path = temp_dir.path().to_str().unwrap().to_string();
        let child_result = Command::new(script_path)
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .arg(data_path.clone())
            .spawn();
        if child_result.is_err() {
            bail!(format!(
                "Unable to run {}: {:?}",
                script_path,
                child_result.err()
            ));
        }
        let mut process = TestProcess {
            child: child_result.unwrap(),
        };
        let stdout = BufReader::new(process.child.stdout.take().unwrap());
        let port = read_process_port(stdout)?;
        let uri = "db://test";
        let host_override = format!("http://localhost:{}", port);
        let connection = create_new_connection(uri, &host_override).await?;
        Ok(TestConnection {
            uri: uri.to_string(),
            connection,
            _temp_dir: Some(temp_dir),
            _process: Some(process),
        })
    }

    fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
        let mut line = String::new();
        let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
        loop {
            let result = stdout.read_line(&mut line);
            if let Err(err) = result {
                bail!(format!(
                    "read_process_port: error while reading from process output: {}",
                    err
                ));
            } else if result.unwrap() == 0 {
                bail!("read_process_port: hit EOF before reading port from process output.");
            }
            if re.is_match(&line) {
                let caps = re.captures(&line).unwrap();
                return Ok(caps[1].to_string());
            }
        }
    }

    #[cfg(feature = "remote")]
    async fn create_new_connection(
        uri: &str,
        host_override: &str,
    ) -> crate::error::Result<Connection> {
        connect(uri)
            .region("us-east-1")
            .api_key("sk_localtest")
            .host_override(host_override)
            .execute()
            .await
    }

    #[cfg(not(feature = "remote"))]
    async fn create_new_connection(
        _uri: &str,
        _host_override: &str,
    ) -> crate::error::Result<Connection> {
        panic!("remote feature not supported");
    }

    async fn new_local_connection() -> Result<TestConnection> {
        let temp_dir = tempdir()?;
        let uri = temp_dir.path().to_str().unwrap();
        let connection = connect(uri).execute().await?;
        Ok(TestConnection {
            uri: uri.to_string(),
            connection,
            _temp_dir: Some(temp_dir),
            _process: None,
        })
    }
}

5
rust/lancedb/src/test_utils.rs
Normal file
@@ -0,0 +1,5 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

pub mod connection;
pub mod datagen;
120
rust/lancedb/src/test_utils/connection.rs
Normal file
@@ -0,0 +1,120 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! Functions for testing connections.

use regex::Regex;
use std::env;
use std::io::{BufRead, BufReader};
use std::process::{Child, ChildStdout, Command, Stdio};

use crate::{connect, Connection};
use anyhow::{bail, Result};
use tempfile::{tempdir, TempDir};

pub struct TestConnection {
    pub uri: String,
    pub connection: Connection,
    _temp_dir: Option<TempDir>,
    _process: Option<TestProcess>,
}

struct TestProcess {
    child: Child,
}

impl Drop for TestProcess {
    #[allow(unused_must_use)]
    fn drop(&mut self) {
        self.child.kill();
    }
}

pub async fn new_test_connection() -> Result<TestConnection> {
    match env::var("CREATE_LANCEDB_TEST_CONNECTION_SCRIPT") {
        Ok(script_path) => new_remote_connection(&script_path).await,
        Err(_e) => new_local_connection().await,
    }
}

async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
    let temp_dir = tempdir()?;
    let data_path = temp_dir.path().to_str().unwrap().to_string();
    let child_result = Command::new(script_path)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .arg(data_path.clone())
        .spawn();
    if child_result.is_err() {
        bail!(format!(
            "Unable to run {}: {:?}",
            script_path,
            child_result.err()
        ));
    }
    let mut process = TestProcess {
        child: child_result.unwrap(),
    };
    let stdout = BufReader::new(process.child.stdout.take().unwrap());
    let port = read_process_port(stdout)?;
    let uri = "db://test";
    let host_override = format!("http://localhost:{}", port);
    let connection = create_new_connection(uri, &host_override).await?;
    Ok(TestConnection {
        uri: uri.to_string(),
        connection,
        _temp_dir: Some(temp_dir),
        _process: Some(process),
    })
}

fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
    let mut line = String::new();
    let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
    loop {
        let result = stdout.read_line(&mut line);
        if let Err(err) = result {
            bail!(format!(
                "read_process_port: error while reading from process output: {}",
                err
            ));
        } else if result.unwrap() == 0 {
            bail!("read_process_port: hit EOF before reading port from process output.");
        }
        if re.is_match(&line) {
            let caps = re.captures(&line).unwrap();
            return Ok(caps[1].to_string());
        }
    }
}

#[cfg(feature = "remote")]
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
    connect(uri)
        .region("us-east-1")
        .api_key("sk_localtest")
        .host_override(host_override)
        .execute()
        .await
}

#[cfg(not(feature = "remote"))]
async fn create_new_connection(
    _uri: &str,
    _host_override: &str,
) -> crate::error::Result<Connection> {
    panic!("remote feature not supported");
}

async fn new_local_connection() -> Result<TestConnection> {
    let temp_dir = tempdir()?;
    let uri = temp_dir.path().to_str().unwrap();
    let connection = connect(uri).execute().await?;
    Ok(TestConnection {
        uri: uri.to_string(),
        connection,
        _temp_dir: Some(temp_dir),
        _process: None,
    })
}

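Since the module is compiled only under `#[cfg(test)]`, it is consumed from the crate's own tests. A sketch of the intended shape follows; the test itself is hypothetical and assumes a tokio test runtime and the `anyhow` crate.

#[tokio::test]
async fn smoke_test() -> anyhow::Result<()> {
    // Spawns the query node named by CREATE_LANCEDB_TEST_CONNECTION_SCRIPT when
    // that variable is set; otherwise falls back to a local temp-dir connection.
    let test_conn = crate::test_utils::connection::new_test_connection().await?;
    let names = test_conn.connection.table_names().execute().await?;
    assert!(names.is_empty());
    Ok(())
}
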
55
rust/lancedb/src/test_utils/datagen.rs
Normal file
@@ -0,0 +1,55 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use arrow_array::RecordBatch;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};

use crate::{
    arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
    connect, Error, Table,
};

#[async_trait::async_trait]
pub trait LanceDbDatagenExt {
    async fn into_mem_table(
        self,
        table_name: &str,
        rows_per_batch: RowCount,
        num_batches: BatchCount,
    ) -> Table;
}

#[async_trait::async_trait]
impl LanceDbDatagenExt for BatchGeneratorBuilder {
    async fn into_mem_table(
        self,
        table_name: &str,
        rows_per_batch: RowCount,
        num_batches: BatchCount,
    ) -> Table {
        let (stream, schema) = self.into_reader_stream(rows_per_batch, num_batches);
        let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
            stream.map_err(Error::from),
            schema,
        ));
        let db = connect("memory:///").execute().await.unwrap();
        db.create_table_streaming(table_name, stream)
            .execute()
            .await
            .unwrap()
    }
}

pub async fn virtual_table(name: &str, values: &RecordBatch) -> Table {
    let schema = values.schema();
    let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
        futures::stream::once(std::future::ready(Ok(values.clone()))),
        schema,
    ));
    let db = connect("memory:///").execute().await.unwrap();
    db.create_table_streaming(name, stream)
        .execute()
        .await
        .unwrap()
}

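The `virtual_table` helper gives tests a one-batch, in-memory table without touching disk. A minimal sketch of calling it (hypothetical helper function; uses only arrow-rs APIs plus the code above):

use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

async fn demo_table() -> crate::Table {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )
    .unwrap();
    // Wraps the single batch in an in-memory table named "demo".
    crate::test_utils::datagen::virtual_table("demo", &batch).await
}
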