Compare commits


35 Commits

Author SHA1 Message Date
Lance Release
972c682857 Bump version: 0.27.1 → 0.28.0-beta.0 2026-02-03 04:47:20 +00:00
LuQQiu
4f8ee82730 chore: update lance core java version to 1.0.4 (#2971) 2026-02-02 20:43:36 -08:00
Will Jones
131024839f fix: include _rowid in hash and calculated split projections (#2965)
## Summary

- PR #2957 changed the permutation builder to only select `_rowid` from
the base table, but `Splitter::project()` for hash and calculated splits
replaced the selection entirely, dropping `_rowid`.
- Include `_rowid` in the column selections for hash and calculated
split projections.
- Fix a Python test that queried the permutation table for base table
columns no longer materialized.

Fixes the `test_split_hash`, `test_split_hash_with_discard`,
`test_split_calculated`, `test_shuffle_combined_with_splits`, and
`test_filter_with_splits` failures in `test_permutation.py`.
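The invariant the fix restores can be sketched in a few lines (a Python illustration with hypothetical names, not the actual Rust `Splitter::project()` code): a split's projection must retain `_rowid` rather than replace the selection wholesale.

```python
ROW_ID = "_rowid"

def project_for_split(split_columns: list[str]) -> list[str]:
    """Build the column selection for a hash/calculated split.

    The fix in spirit: keep _rowid in the selection instead of replacing
    the projection with only the split's own columns.
    """
    selection = list(split_columns)
    if ROW_ID not in selection:
        # Before the fix, the equivalent of `return list(split_columns)`
        # silently dropped _rowid whenever the split did not reference it.
        selection.insert(0, ROW_ID)
    return selection
```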

## Test plan

- [x] `cargo test -p lancedb -- permutation` (22 passed)
- [x] `pytest python/tests/test_permutation.py` (46 passed)
- [x] `npm test __test__/permutation.test.ts` (20 passed)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 16:27:58 -08:00
ChinmayGowda71
3c7ddf4d0c refactor: modularize table.rs and extract delete logic (#2952)
References #2949. Moved `DeleteResult` and the `delete()` implementation to
src/table/delete.rs. No functional changes. Added a test for delete, which
passes. Will work on refactoring update next.
2026-02-02 11:54:49 -08:00
Siyuan Huang
461176f9f2 docs: update REST API link in README.md (#2906)
Fix broken REST API docs link in README.md by replacing
https://docs.lancedb.com/api-reference/introduction (404) with
https://docs.lancedb.com/api-reference/rest
2026-01-30 15:49:41 -08:00
Aman Harsh
3b8996bb69 fix(python): cancel remote queries on sync API interruption (#2913)
Fixes #2898 

Problem:
Sync API cancellations didn’t stop remote query coroutines, so requests
could continue after interrupt.

Changes:
- Cancel run_coroutine_threadsafe futures on any BaseException in the
sync background loop
- Update cancellation test to avoid starting a real background thread
and cover GeneratorExit
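The cancellation pattern described above can be sketched as follows (a minimal illustration of the assumed mechanism, not the actual lancedb internals): a sync wrapper drives coroutines on a background event loop and cancels the in-flight future when the caller is interrupted.

```python
import asyncio
import threading

# Background event loop running on a daemon thread, as a sync API
# typically uses to drive async remote clients.
_loop = asyncio.new_event_loop()
_thread = threading.Thread(target=_loop.run_forever, daemon=True)
_thread.start()

def run_sync(coro):
    future = asyncio.run_coroutine_threadsafe(coro, _loop)
    try:
        return future.result()
    except BaseException:
        # BaseException, not Exception: KeyboardInterrupt and GeneratorExit
        # bypass `except Exception`, which is how remote queries could keep
        # running after an interrupt before this fix.
        future.cancel()
        raise
```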
2026-01-30 15:47:18 -08:00
Mesut-Doner
3755064e93 fix(rust): support embeddings in create_empty_table (#2961)
Fixes the Rust SDK's `create_empty_table` to properly support embedding
column definitions, bringing it to parity with the Python SDK.

## Problem

The Rust SDK's `Connection::create_empty_table` did not support setting
embedding columns. When using `.add_embedding()` on the builder, the
embedding column definitions were lost because
`TableDefinition::new_from_schema(schema)` marks all columns as physical
only, without embedding metadata.

The Python SDK worked around this by creating an empty record batch with
proper schema metadata rather than using `create_empty_table` directly.

## Solution
Modified `CreateTableBuilder<false>` to handle embeddings
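A hypothetical miniature of the lossy path and the fix (illustrative Python names, not the real Rust types): `new_from_schema` derives columns only, so embedding definitions added on the builder have to be threaded through explicitly.

```python
from dataclasses import dataclass, field

@dataclass
class EmbeddingDef:
    # Illustrative fields; the real definition lives in the Rust SDK.
    source_column: str
    dest_column: str
    function_name: str

@dataclass
class TableDefinition:
    columns: list
    embeddings: list = field(default_factory=list)

    @classmethod
    def new_from_schema(cls, schema):
        # Lossy path: marks everything physical; no embedding metadata survives.
        return cls(columns=list(schema))

def create_empty_table_definition(schema, embedding_defs):
    # The fix, in spirit: carry the builder's embedding definitions into the
    # table definition instead of deriving it from the schema alone.
    definition = TableDefinition.new_from_schema(schema)
    definition.embeddings = list(embedding_defs)
    return definition
```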

Closes #2759
2026-01-30 15:44:18 -08:00
Xin Sun
8773b865a9 fix(python): uses PIL incorrectly and may raise AttributeError (#2954)
Importing `PIL` alone does not guarantee that the `Image` submodule is
loaded. In a clean environment where no other code has imported
`PIL.Image` before, `PIL.Image` does not exist on the `PIL` package,
which leads to the AttributeError.
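The same pitfall can be reproduced with a stdlib package, so this sketch runs without Pillow installed: importing a top-level package does not import its submodules, so `import PIL` does not guarantee `PIL.Image` exists, just as `import xml` does not make `xml.etree` available in a clean interpreter. The fix is to import the submodule explicitly (`from PIL import Image`).

```python
import subprocess
import sys

# Buggy pattern: rely on the top-level import alone.
probe = "import xml; print(hasattr(xml, 'etree'))"
clean = subprocess.run([sys.executable, "-I", "-c", probe],
                       capture_output=True, text=True)

# Fixed pattern: import the submodule explicitly, which also binds it
# as an attribute on the parent package.
fixed = "import xml.etree; import xml; print(hasattr(xml, 'etree'))"
ok = subprocess.run([sys.executable, "-I", "-c", fixed],
                    capture_output=True, text=True)

print(clean.stdout.strip(), ok.stdout.strip())
```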
2026-01-30 15:33:10 -08:00
fzowl
1ee29675b3 feat(python): adding VoyageAI v4 models (#2959)
Adds VoyageAI v4 models:
- added unit tests for the new models
- added example code (tested)
2026-01-30 15:16:03 -08:00
Weston Pace
9be28448f5 fix: don't store all columns in the permutation table (#2957)
The permutation table was always intended to be a small table of row id
pointers (and split id). However, it was accidentally doing a full
materialization of the base table 🤦

This PR changes the permutation builder to only store row id and split
id.
2026-01-29 16:06:36 -08:00
Lei Xu
357197bacc chore!: change support python version from 3.10 to 3.13 (#2955)
Python 3.9 has been EOL since October 2025, and the last two pyarrow
releases were built against Python 3.10-3.13.

* This PR is contributed by codex-gpt5.2
2026-01-30 01:47:50 +08:00
Lei Xu
ad51e2dd1f fix: support pydantic list of structs or optional struct (#2953)
Closes #2950

*This code is generated by codex-gpt5.2*
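The kind of annotation unwrapping such a fix needs can be sketched with stdlib typing introspection (an illustration, not the actual lancedb code): mapping a pydantic model to an Arrow schema requires peeling `Optional[Struct]` and `List[Struct]` back to the nested model.

```python
import types
from typing import List, Optional, Union, get_args, get_origin

# `X | None` uses types.UnionType on Python 3.10+; Optional[X] uses typing.Union.
_UNION_TYPES = {Union, getattr(types, "UnionType", Union)}

def unwrap_annotation(annotation):
    """Return (inner_type, is_list, is_optional) for a field annotation."""
    origin = get_origin(annotation)
    if origin in _UNION_TYPES:
        non_none = [a for a in get_args(annotation) if a is not type(None)]
        inner, is_list, _ = unwrap_annotation(non_none[0])
        return inner, is_list, True
    if origin is list:
        inner, _, is_optional = unwrap_annotation(get_args(annotation)[0])
        return inner, True, is_optional
    return annotation, False, False
```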
2026-01-28 21:08:18 -08:00
Weston Pace
e9e904783c feat: allow the permutation builder memory limit to be configured by env var (#2946)
Running into issues with DataFusion sorting again. This will at least allow
the memory limit to be set high to bypass problems.
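The pattern is a simple environment-variable override with a fallback; a minimal sketch (the variable name here is an assumption for illustration, not necessarily the one the PR introduces):

```python
import os

DEFAULT_LIMIT_BYTES = 1 * 1024 * 1024 * 1024  # assumed 1 GiB default

def permutation_memory_limit() -> int:
    # Hypothetical variable name; check the PR for the actual one.
    raw = os.environ.get("LANCEDB_PERMUTATION_MEM_LIMIT")
    if raw is None:
        return DEFAULT_LIMIT_BYTES
    try:
        return int(raw)
    except ValueError:
        # Malformed values fall back to the default rather than failing.
        return DEFAULT_LIMIT_BYTES
```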
2026-01-28 09:02:59 +05:30
Lance Release
8500b16eca Bump version: 0.24.1-beta.0 → 0.24.1 2026-01-26 23:39:18 +00:00
Lance Release
57e7282342 Bump version: 0.24.0 → 0.24.1-beta.0 2026-01-26 23:38:50 +00:00
Lance Release
cc5f8070d7 Bump version: 0.27.1-beta.0 → 0.27.1 2026-01-26 23:38:24 +00:00
Lance Release
dc0fb01f6b Bump version: 0.27.0 → 0.27.1-beta.0 2026-01-26 23:38:23 +00:00
LanceDB Robot
94b7781551 feat: update lance dependency to v1.0.4 (#2944)
## Summary
- bump Lance dependencies to v1.0.4
- run `cargo clippy --workspace --tests --all-features -- -D warnings`
- run `cargo fmt --all`

## Testing
- `cargo clippy --workspace --tests --all-features -- -D warnings`

## Reference
- https://github.com/lance-format/lance/releases/tag/v1.0.4
2026-01-26 15:37:28 -08:00
Jack Ye
7bf020b3d5 chore: fix clippy when remote flag is not set (#2943)
Also add a step in CI to ensure this does not happen in the future
2026-01-26 13:59:31 -08:00
LanceDB Robot
12a98479dc chore: update lance dependency to v1.0.4-rc.1 (#2942)
## Summary
- bump Lance dependencies to v1.0.4-rc.1
- verified `cargo clippy --workspace --tests --all-features -- -D
warnings`
- ran `cargo fmt --all`

## References
- https://github.com/lance-format/lance/releases/tag/v1.0.4-rc.1
2026-01-26 12:17:22 -08:00
Jack Ye
e4552e577a chore(revert): revert update lance dependency to v2.0.0-rc.1 (#2936) (#2941)
This reverts commit bd84bba14d, so that we
can bump version to 1.0.4-rc.1
2026-01-26 11:13:59 -08:00
Will Jones
f979a902ad ci(rust): fix MSRV check (#2940)
Realized our MSRV check was inert because `rust-toolchain.toml` was
overriding the Rust version. We set the `RUSTUP_TOOLCHAIN` environment
variable, which overrides that.

Also needed to update to MSRV 1.88 (due to dependencies like Lance and
DataFusion) and fix some clippy warnings.
2026-01-23 15:57:09 -08:00
Colin Patrick McCabe
5a7a8da567 feat: check AZURE_STORAGE_ACCOUNT_NAME in remote conns (#2918)
Unlike in Amazon S3, in Azure bucket names are not globally unique.
Instead, the combination of (storage_account_name, bucket_name) is
unique.

Therefore, when using Azure blob store, we always need a way to
configure the storage account name. One way is to use the
storage_options hash map and set azure_storage_account_name. Another way
is to set an environment variable, AZURE_STORAGE_ACCOUNT_NAME.

Prior to this PR, the second way (environment variable) did not work
with remote connections. This is because the existing code that checks
for these environment variables happens inside the Azure object store
implementation itself, which does not run locally when using remote
connections.

This PR addresses that situation by adding a check of the environment
variable. This functions as a default if the relevant storage option is
not set in the storage_options hash map.
2026-01-22 13:36:05 -08:00
Jack Ye
0db8176445 test: fix failing remote doctest reference to aws feature (#2935)
Closes https://github.com/lancedb/lancedb/issues/2933
2026-01-22 13:17:03 -08:00
LanceDB Robot
bd84bba14d chore: update lance dependency to v2.0.0-rc.1 (#2936)
## Summary
- bump Lance dependencies to v2.0.0-rc.1 (git tag)
- align Arrow/DataFusion/PyO3 versions for the new Lance release
- update Python bindings for PyO3 0.26 (attach API + Py<PyAny>)

## Verification
- `cargo clippy --workspace --tests --all-features -- -D warnings`
- `cargo fmt --all`

## Reference
- https://github.com/lance-format/lance/releases/tag/v2.0.0-rc.1

---------

Co-authored-by: Jack Ye <yezhaoqin@gmail.com>
Co-authored-by: Will Jones <willjones127@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: BubbleCal <bubble_cal@outlook.com>
2026-01-22 13:14:38 -08:00
Lance Release
ac07f8068c Bump version: 0.24.0-beta.1 → 0.24.0 2026-01-22 01:10:15 +00:00
Lance Release
bba362d372 Bump version: 0.24.0-beta.0 → 0.24.0-beta.1 2026-01-22 01:09:53 +00:00
Lance Release
042bc22468 Bump version: 0.27.0-beta.1 → 0.27.0 2026-01-22 01:09:32 +00:00
Lance Release
68569906c6 Bump version: 0.27.0-beta.0 → 0.27.0-beta.1 2026-01-22 01:09:31 +00:00
LanceDB Robot
c71c1fc822 feat: update lance dependency to v1.0.3 (#2932)
## Summary
- bump Lance dependency to v1.0.3
- refresh Cargo metadata and lockfile

## Verification
- cargo clippy --workspace --tests --all-features -- -D warnings
- cargo fmt --all

## Release
- https://github.com/lance-format/lance/releases/tag/v1.0.3
2026-01-21 17:08:24 -08:00
Jack Ye
4a6a0c856e ci: fix codex version bump title and summary (#2931)
1. use feat for releases, chore for prereleases
2. do not have literal `\n` in summary
2026-01-21 15:45:28 -08:00
Jack Ye
f124c9d8d2 test: string type conversion in pandas 3.0+ (#2928)
Pandas 3.0+ strings now convert to Arrow large_utf8. This PR mainly makes
sure our tests account for the difference across pandas versions when
constructing schemas.
2026-01-21 13:40:48 -08:00
Jack Ye
4e65748abf chore: update lance dependency to v1.0.3-rc.1 (#2927)
Supersedes https://github.com/lancedb/lancedb/pull/2925

We accidentally upgraded lance to 2.0.0-beta.8. This PR reverts that first
and then bumps to 1.0.3-rc.1.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 11:52:07 -08:00
Colin Patrick McCabe
e897f3edab test: assert remote behavior of drop_table (#2926)
Add support for testing remote connections in drop_table in
`rust/lancedb/src/connection.rs`.
2026-01-21 08:42:40 -08:00
Lance Release
790ba7115b Bump version: 0.23.1 → 0.24.0-beta.0 2026-01-21 12:21:53 +00:00
70 changed files with 1951 additions and 920 deletions

View File

```diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.23.1"
+current_version = "0.24.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
```

View File

```diff
@@ -3,7 +3,7 @@ name: build-linux-wheel
 description: "Build a manylinux wheel for lance"
 inputs:
   python-minor-version:
-    description: "8, 9, 10, 11, 12"
+    description: "10, 11, 12, 13"
     required: true
   args:
     description: "--release"
```

View File

```diff
@@ -3,7 +3,7 @@ name: build_wheel
 description: "Build a lance wheel"
 inputs:
   python-minor-version:
-    description: "8, 9, 10, 11"
+    description: "10, 11, 12, 13"
     required: true
   args:
     description: "--release"
```

View File

```diff
@@ -3,7 +3,7 @@ name: build_wheel
 description: "Build a lance wheel"
 inputs:
   python-minor-version:
-    description: "8, 9, 10, 11"
+    description: "10, 11, 12, 13, 14"
     required: true
   args:
     description: "--release"
```

View File

```diff
@@ -75,6 +75,13 @@ jobs:
           VERSION="${VERSION#v}"
           BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
+
+          # Use "chore" for beta/rc versions, "feat" for stable releases
+          if [[ "${VERSION}" == *beta* ]] || [[ "${VERSION}" == *rc* ]]; then
+            COMMIT_TYPE="chore"
+          else
+            COMMIT_TYPE="feat"
+          fi
           cat <<EOF >/tmp/codex-prompt.txt
           You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
@@ -84,10 +91,10 @@ jobs:
           3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
           4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
           5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
-          6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
+          6. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
           7. Push the branch to origin. If the branch already exists, force-push your changes.
           8. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
-          9. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}).
+          9. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
           10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
           Constraints:
```

View File

```diff
@@ -41,7 +41,7 @@ jobs:
         sudo apt install -y protobuf-compiler libssl-dev
         rustup update && rustup default
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
         python-version: "3.10"
         cache: "pip"
```

View File

```diff
@@ -44,12 +44,12 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v4
+      uses: actions/setup-python@v6
       with:
-        python-version: 3.8
+        python-version: "3.10"
     - uses: ./.github/workflows/build_linux_wheel
       with:
-        python-minor-version: 8
+        python-minor-version: 10
         args: "--release --strip ${{ matrix.config.extra_args }}"
         arm-build: ${{ matrix.config.platform == 'aarch64' }}
         manylinux: ${{ matrix.config.manylinux }}
@@ -74,12 +74,12 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v4
+      uses: actions/setup-python@v6
       with:
-        python-version: 3.12
+        python-version: "3.13"
     - uses: ./.github/workflows/build_mac_wheel
       with:
-        python-minor-version: 8
+        python-minor-version: 10
         args: "--release --strip --target ${{ matrix.config.target }} --features fp16kernels"
     - uses: ./.github/workflows/upload_wheel
       if: startsWith(github.ref, 'refs/tags/python-v')
@@ -95,12 +95,12 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v4
+      uses: actions/setup-python@v6
       with:
-        python-version: 3.12
+        python-version: "3.13"
     - uses: ./.github/workflows/build_windows_wheel
       with:
-        python-minor-version: 8
+        python-minor-version: 10
         args: "--release --strip"
         vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }}
     - uses: ./.github/workflows/upload_wheel
```

View File

```diff
@@ -36,9 +36,9 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
-        python-version: "3.12"
+        python-version: "3.13"
     - name: Install ruff
       run: |
         pip install ruff==0.9.9
@@ -61,9 +61,9 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
-        python-version: "3.12"
+        python-version: "3.13"
     - name: Install protobuf compiler
       run: |
         sudo apt update
@@ -90,9 +90,9 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
-        python-version: "3.12"
+        python-version: "3.13"
         cache: "pip"
     - name: Install protobuf
       run: |
@@ -110,7 +110,7 @@ jobs:
     timeout-minutes: 30
     strategy:
       matrix:
-        python-minor-version: ["9", "12"]
+        python-minor-version: ["10", "13"]
     runs-on: "ubuntu-24.04"
     defaults:
       run:
@@ -126,7 +126,7 @@ jobs:
         sudo apt update
         sudo apt install -y protobuf-compiler
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
         python-version: 3.${{ matrix.python-minor-version }}
     - uses: ./.github/workflows/build_linux_wheel
@@ -156,9 +156,9 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
-        python-version: "3.12"
+        python-version: "3.13"
     - uses: ./.github/workflows/build_mac_wheel
       with:
         args: --profile ci
@@ -185,9 +185,9 @@ jobs:
         fetch-depth: 0
         lfs: true
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
-        python-version: "3.12"
+        python-version: "3.13"
     - uses: ./.github/workflows/build_windows_wheel
       with:
         args: --profile ci
@@ -212,9 +212,9 @@ jobs:
         sudo apt update
         sudo apt install -y protobuf-compiler
     - name: Set up Python
-      uses: actions/setup-python@v5
+      uses: actions/setup-python@v6
       with:
-        python-version: 3.9
+        python-version: "3.10"
     - name: Install lancedb
       run: |
         pip install "pydantic<2"
```

View File

```diff
@@ -48,6 +48,8 @@ jobs:
       run: cargo fmt --all -- --check
     - name: Run clippy
       run: cargo clippy --profile ci --workspace --tests --all-features -- -D warnings
+    - name: Run clippy (without remote feature)
+      run: cargo clippy --profile ci --workspace --tests -- -D warnings

   build-no-lock:
     runs-on: ubuntu-24.04
@@ -181,7 +183,7 @@ jobs:
     runs-on: ubuntu-24.04
     strategy:
       matrix:
-        msrv: ["1.78.0"] # This should match up with rust-version in Cargo.toml
+        msrv: ["1.88.0"] # This should match up with rust-version in Cargo.toml
     env:
       # Need up-to-date compilers for kernels
       CC: clang-18
@@ -212,4 +214,6 @@ jobs:
           cargo update -p aws-sdk-sts --precise 1.51.0
           cargo update -p home --precise 0.5.9
     - name: cargo +${{ matrix.msrv }} check
+      env:
+        RUSTUP_TOOLCHAIN: ${{ matrix.msrv }}
       run: cargo check --profile ci --workspace --tests --benches --all-features
```

Cargo.lock (generated): file diff suppressed because it is too large.

View File

```diff
@@ -12,42 +12,42 @@ repository = "https://github.com/lancedb/lancedb"
 description = "Serverless, low-latency vector database for AI applications"
 keywords = ["lancedb", "lance", "database", "vector", "search"]
 categories = ["database-implementations"]
-rust-version = "1.78.0"
+rust-version = "1.88.0"

 [workspace.dependencies]
-lance = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-core = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-core = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-datagen = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-datagen = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-file = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-file = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-io = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-io = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-index = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-index = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-linalg = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-linalg = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-namespace = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-namespace-impls = { "version" = "=2.0.0-beta.8", default-features = false, "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace-impls = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-table = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-table = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-testing = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-testing = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-datafusion = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-datafusion = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-encoding = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-encoding = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
-lance-arrow = { "version" = "=2.0.0-beta.8", "tag" = "v2.0.0-beta.8", "git" = "https://github.com/lance-format/lance.git" }
+lance-arrow = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
 ahash = "0.8"
 # Note that this one does not include pyarrow
-arrow = { version = "57.2", optional = false }
+arrow = { version = "56.2", optional = false }
-arrow-array = "57.2"
+arrow-array = "56.2"
-arrow-data = "57.2"
+arrow-data = "56.2"
-arrow-ipc = "57.2"
+arrow-ipc = "56.2"
-arrow-ord = "57.2"
+arrow-ord = "56.2"
-arrow-schema = "57.2"
+arrow-schema = "56.2"
-arrow-select = "57.2"
+arrow-select = "56.2"
-arrow-cast = "57.2"
+arrow-cast = "56.2"
 async-trait = "0"
-datafusion = { version = "51.0", default-features = false }
+datafusion = { version = "50.1", default-features = false }
-datafusion-catalog = "51.0"
+datafusion-catalog = "50.1"
-datafusion-common = { version = "51.0", default-features = false }
+datafusion-common = { version = "50.1", default-features = false }
-datafusion-execution = "51.0"
+datafusion-execution = "50.1"
-datafusion-expr = "51.0"
+datafusion-expr = "50.1"
-datafusion-physical-plan = "51.0"
+datafusion-physical-plan = "50.1"
 env_logger = "0.11"
-half = { "version" = "2.7.1", default-features = false, features = [
+half = { "version" = "2.6.0", default-features = false, features = [
     "num-traits",
 ] }
 futures = "0"
@@ -59,7 +59,7 @@ rand = "0.9"
 snafu = "0.8"
 url = "2"
 num-traits = "0.2"
-regex = "1.12"
+regex = "1.10"
 lazy_static = "1"
 semver = "1.0.25"
 chrono = "0.4"
```

View File

```diff
@@ -66,7 +66,7 @@ Follow the [Quickstart](https://lancedb.com/docs/quickstart/) doc to set up Lanc
 | Python SDK | https://lancedb.github.io/lancedb/python/python/ |
 | Typescript SDK | https://lancedb.github.io/lancedb/js/globals/ |
 | Rust SDK | https://docs.rs/lancedb/latest/lancedb/index.html |
-| REST API | https://docs.lancedb.com/api-reference/introduction |
+| REST API | https://docs.lancedb.com/api-reference/rest |

 ## **Join Us and Contribute**
```

View File

@@ -0,0 +1,62 @@
# VoyageAI Embeddings

Voyage AI provides cutting-edge embedding models and rerankers.

Using the VoyageAI API requires the `voyageai` package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.

You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.

Supported models are:

**Voyage-4 Series (Latest)**

- voyage-4 (1024 dims, general-purpose and multilingual retrieval, 320K batch tokens)
- voyage-4-lite (1024 dims, optimized for latency and cost, 1M batch tokens)
- voyage-4-large (1024 dims, best retrieval quality, 120K batch tokens)

**Voyage-3 Series**

- voyage-3
- voyage-3-lite

**Domain-Specific Models**

- voyage-finance-2
- voyage-multilingual-2
- voyage-law-2
- voyage-code-2

Supported parameters (to be passed in the `create` method) are:

| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `None` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-4, voyage-4-lite, voyage-4-large, voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
| `input_type` | `str` | `None` | Type of the input text. Defaults to `None`. Other options: `query`, `document`. |
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |

Usage Example:

```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry

voyageai = (
    EmbeddingFunctionRegistry.get_instance()
    .get("voyageai")
    .create(name="voyage-3")
)


class TextModel(LanceModel):
    text: str = voyageai.SourceField()
    vector: Vector(voyageai.ndims()) = voyageai.VectorField()


data = [
    {"text": "hello world"},
    {"text": "goodbye world"},
]

db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```

View File

````diff
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
 <dependency>
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-core</artifactId>
-  <version>0.23.1</version>
+  <version>0.24.1</version>
 </dependency>
 ```
````

View File

```diff
@@ -8,7 +8,7 @@
 <parent>
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.23.1-final.0</version>
+  <version>0.24.1-final.0</version>
   <relativePath>../pom.xml</relativePath>
 </parent>
```

View File

```diff
@@ -6,7 +6,7 @@
 <groupId>com.lancedb</groupId>
 <artifactId>lancedb-parent</artifactId>
-<version>0.23.1-final.0</version>
+<version>0.24.1-final.0</version>
 <packaging>pom</packaging>
 <name>${project.artifactId}</name>
 <description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
 <properties>
   <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
   <arrow.version>15.0.0</arrow.version>
-  <lance-core.version>1.0.0-rc.2</lance-core.version>
+  <lance-core.version>1.0.4</lance-core.version>
   <spotless.skip>false</spotless.skip>
   <spotless.version>2.30.0</spotless.version>
   <spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
```

View File

```diff
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.23.1"
+version = "0.24.1"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
```

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-arm64", "name": "@lancedb/lancedb-darwin-arm64",
"version": "0.23.1", "version": "0.24.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node", "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
 {
 "name": "@lancedb/lancedb-darwin-x64",
-"version": "0.23.1",
+"version": "0.24.1",
 "os": ["darwin"],
 "cpu": ["x64"],
 "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
 {
 "name": "@lancedb/lancedb-linux-arm64-gnu",
-"version": "0.23.1",
+"version": "0.24.1",
 "os": ["linux"],
 "cpu": ["arm64"],
 "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
 {
 "name": "@lancedb/lancedb-linux-arm64-musl",
-"version": "0.23.1",
+"version": "0.24.1",
 "os": ["linux"],
 "cpu": ["arm64"],
 "main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
 {
 "name": "@lancedb/lancedb-linux-x64-gnu",
-"version": "0.23.1",
+"version": "0.24.1",
 "os": ["linux"],
 "cpu": ["x64"],
 "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
 {
 "name": "@lancedb/lancedb-linux-x64-musl",
-"version": "0.23.1",
+"version": "0.24.1",
 "os": ["linux"],
 "cpu": ["x64"],
 "main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
 {
 "name": "@lancedb/lancedb-win32-arm64-msvc",
-"version": "0.23.1",
+"version": "0.24.1",
 "os": [
 "win32"
 ],

View File

@@ -1,6 +1,6 @@
 {
 "name": "@lancedb/lancedb-win32-x64-msvc",
-"version": "0.23.1",
+"version": "0.24.1",
 "os": ["win32"],
 "cpu": ["x64"],
 "main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
 {
 "name": "@lancedb/lancedb",
-"version": "0.23.1",
+"version": "0.24.1",
 "lockfileVersion": 3,
 "requires": true,
 "packages": {
 "": {
 "name": "@lancedb/lancedb",
-"version": "0.23.1",
+"version": "0.24.1",
 "cpu": [
 "x64",
 "arm64"

View File

@@ -11,7 +11,7 @@
 "ann"
 ],
 "private": false,
-"version": "0.23.1",
+"version": "0.24.1",
 "main": "dist/index.js",
 "exports": {
 ".": "./dist/index.js",

View File

@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.27.0-beta.0"
+current_version = "0.28.0-beta.0"
 parse = """(?x)
 (?P<major>0|[1-9]\\d*)\\.
 (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -16,7 +16,7 @@ The Python package is a wrapper around the Rust library, `lancedb`. We use
 To set up your development environment, you will need to install the following:
-1. Python 3.9 or later
+1. Python 3.10 or later
 2. Cargo (Rust's package manager). Use [rustup](https://rustup.rs/) to install.
 3. [protoc](https://grpc.io/docs/protoc-installation/) (Protocol Buffers compiler)

View File

@@ -1,28 +1,28 @@
 [package]
 name = "lancedb-python"
-version = "0.27.0-beta.0"
+version = "0.28.0-beta.0"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
 repository.workspace = true
 keywords.workspace = true
 categories.workspace = true
-rust-version = "1.75.0"
+rust-version = "1.88.0"
 [lib]
 name = "_lancedb"
 crate-type = ["cdylib"]
 [dependencies]
-arrow = { version = "57.2", features = ["pyarrow"] }
+arrow = { version = "56.2", features = ["pyarrow"] }
 async-trait = "0.1"
 lancedb = { path = "../rust/lancedb", default-features = false }
 lance-core.workspace = true
 lance-namespace.workspace = true
 lance-io.workspace = true
 env_logger.workspace = true
-pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
-pyo3-async-runtimes = { version = "0.26", features = [
+pyo3 = { version = "0.25", features = ["extension-module", "abi3-py310"] }
+pyo3-async-runtimes = { version = "0.25", features = [
 "attributes",
 "tokio-runtime",
 ] }
@@ -32,9 +32,9 @@ snafu.workspace = true
 tokio = { version = "1.40", features = ["sync"] }
 [build-dependencies]
-pyo3-build-config = { version = "0.26", features = [
+pyo3-build-config = { version = "0.25", features = [
 "extension-module",
-"abi3-py39",
+"abi3-py310",
 ] }
 [features]
View File

@@ -16,7 +16,7 @@ description = "lancedb"
 authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
 license = { file = "LICENSE" }
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 keywords = [
 "data-format",
 "data-science",
@@ -33,10 +33,10 @@ classifiers = [
 "Programming Language :: Python",
 "Programming Language :: Python :: 3",
 "Programming Language :: Python :: 3 :: Only",
-"Programming Language :: Python :: 3.9",
 "Programming Language :: Python :: 3.10",
 "Programming Language :: Python :: 3.11",
 "Programming Language :: Python :: 3.12",
+"Programming Language :: Python :: 3.13",
 "Topic :: Scientific/Engineering",
 ]
@@ -137,4 +137,4 @@ include = [
 "python/lancedb/_lancedb.pyi",
 ]
 exclude = ["python/tests/"]
-pythonVersion = "3.12"
+pythonVersion = "3.13"

View File

@@ -22,7 +22,12 @@ class BackgroundEventLoop:
         self.thread.start()

     def run(self, future):
-        return asyncio.run_coroutine_threadsafe(future, self.loop).result()
+        concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop)
+        try:
+            return concurrent_future.result()
+        except BaseException:
+            concurrent_future.cancel()
+            raise

 LOOP = BackgroundEventLoop()
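The `run()` change above cancels the coroutine scheduled on the background loop whenever the blocking `result()` call is interrupted (e.g. by `KeyboardInterrupt` in the sync API), instead of leaving it running. A minimal, self-contained sketch of the same pattern, independent of LanceDB (class and variable names mirror the diff but are illustrative):

```python
import asyncio
import threading


class BackgroundEventLoop:
    """Run coroutines on a dedicated event-loop thread (illustrative sketch)."""

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self.thread.start()

    def run(self, future):
        concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop)
        try:
            return concurrent_future.result()
        except BaseException:
            # Interrupted while blocked in result(): propagate the exception,
            # but cancel the remote coroutine first so it stops running.
            concurrent_future.cancel()
            raise


loop = BackgroundEventLoop()
result = loop.run(asyncio.sleep(0, result="done"))
```

Catching `BaseException` (not just `Exception`) is what makes this work for `KeyboardInterrupt`, which does not derive from `Exception`.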

View File

@@ -275,7 +275,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
         """
         Convert image inputs to PIL Images.
         """
-        PIL = attempt_import_or_raise("PIL", "pillow")
+        PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
         requests = attempt_import_or_raise("requests", "requests")
         images = self.sanitize_input(images)
         pil_images = []
@@ -285,12 +285,12 @@ class ColPaliEmbeddings(EmbeddingFunction):
                 if image.startswith(("http://", "https://")):
                     response = requests.get(image, timeout=10)
                     response.raise_for_status()
-                    pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
+                    pil_images.append(PIL_Image.open(io.BytesIO(response.content)))
                 else:
-                    with PIL.Image.open(image) as im:
+                    with PIL_Image.open(image) as im:
                         pil_images.append(im.copy())
             elif isinstance(image, bytes):
-                pil_images.append(PIL.Image.open(io.BytesIO(image)))
+                pil_images.append(PIL_Image.open(io.BytesIO(image)))
             else:
                 # Assume it's a PIL Image; will raise if invalid
                 pil_images.append(image)
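The rename from `PIL` to `PIL_Image` in this and the following embedding modules matters because importing a top-level package does not guarantee its submodules are loaded: asking the helper for `"PIL.Image"` imports the submodule explicitly. A sketch of such a helper; this `attempt_import_or_raise` is a simplified stand-in for LanceDB's, demonstrated with a stdlib dotted module so it runs without pillow installed:

```python
import importlib


def attempt_import_or_raise(module: str, package: str):
    # Lazily import `module` (which may be a dotted submodule path) and raise
    # an actionable error naming the pip package if it is missing.
    try:
        return importlib.import_module(module)
    except ImportError as e:
        raise ImportError(f"Please install the '{package}' package") from e


# A dotted path returns the submodule object itself, so callers can use its
# attributes directly without relying on the parent package having imported it.
etree = attempt_import_or_raise("xml.etree.ElementTree", "stdlib")
root = etree.fromstring("<a><b/></a>")
```

With the real helper, `PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")` then supports `PIL_Image.open(...)` and `isinstance(x, PIL_Image.Image)` safely.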

View File

@@ -77,8 +77,8 @@ class JinaEmbeddings(EmbeddingFunction):
         if isinstance(inputs, list):
             inputs = inputs
         else:
-            PIL = attempt_import_or_raise("PIL", "pillow")
-            if isinstance(inputs, PIL.Image.Image):
+            PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
+            if isinstance(inputs, PIL_Image.Image):
                 inputs = [inputs]
         return inputs
@@ -89,13 +89,13 @@ class JinaEmbeddings(EmbeddingFunction):
         elif isinstance(image, (str, Path)):
             parsed = urlparse.urlparse(image)
             # TODO handle drive letter on windows.
-            PIL = attempt_import_or_raise("PIL", "pillow")
+            PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
             if parsed.scheme == "file":
-                pil_image = PIL.Image.open(parsed.path)
+                pil_image = PIL_Image.open(parsed.path)
             elif parsed.scheme == "":
-                pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
+                pil_image = PIL_Image.open(image if os.name == "nt" else parsed.path)
             elif parsed.scheme.startswith("http"):
-                pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
+                pil_image = PIL_Image.open(io.BytesIO(url_retrieve(image)))
             else:
                 raise NotImplementedError("Only local and http(s) urls are supported")
             buffered = io.BytesIO()
@@ -103,9 +103,9 @@ class JinaEmbeddings(EmbeddingFunction):
             image_bytes = buffered.getvalue()
             image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
         else:
-            PIL = attempt_import_or_raise("PIL", "pillow")
-            if isinstance(image, PIL.Image.Image):
+            PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
+            if isinstance(image, PIL_Image.Image):
                 buffered = io.BytesIO()
                 image.save(buffered, format="PNG")
                 image_bytes = buffered.getvalue()
@@ -136,9 +136,9 @@ class JinaEmbeddings(EmbeddingFunction):
         elif isinstance(query, (Path, bytes)):
             return [self.generate_image_embedding(query)]
         else:
-            PIL = attempt_import_or_raise("PIL", "pillow")
-            if isinstance(query, PIL.Image.Image):
+            PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
+            if isinstance(query, PIL_Image.Image):
                 return [self.generate_image_embedding(query)]
             else:
                 raise TypeError(

View File

@@ -71,8 +71,8 @@ class OpenClipEmbeddings(EmbeddingFunction):
         if isinstance(query, str):
             return [self.generate_text_embeddings(query)]
         else:
-            PIL = attempt_import_or_raise("PIL", "pillow")
-            if isinstance(query, PIL.Image.Image):
+            PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
+            if isinstance(query, PIL_Image.Image):
                 return [self.generate_image_embedding(query)]
             else:
                 raise TypeError("OpenClip supports str or PIL Image as query")
@@ -145,20 +145,20 @@ class OpenClipEmbeddings(EmbeddingFunction):
         return self._encode_and_normalize_image(image)

     def _to_pil(self, image: Union[str, bytes]):
-        PIL = attempt_import_or_raise("PIL", "pillow")
+        PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
         if isinstance(image, bytes):
-            return PIL.Image.open(io.BytesIO(image))
-        if isinstance(image, PIL.Image.Image):
+            return PIL_Image.open(io.BytesIO(image))
+        if isinstance(image, PIL_Image.Image):
             return image
         elif isinstance(image, str):
             parsed = urlparse.urlparse(image)
             # TODO handle drive letter on windows.
             if parsed.scheme == "file":
-                return PIL.Image.open(parsed.path)
+                return PIL_Image.open(parsed.path)
             elif parsed.scheme == "":
-                return PIL.Image.open(image if os.name == "nt" else parsed.path)
+                return PIL_Image.open(image if os.name == "nt" else parsed.path)
             elif parsed.scheme.startswith("http"):
-                return PIL.Image.open(io.BytesIO(url_retrieve(image)))
+                return PIL_Image.open(io.BytesIO(url_retrieve(image)))
             else:
                 raise NotImplementedError("Only local and http(s) urls are supported")

View File

@@ -56,8 +56,8 @@ class SigLipEmbeddings(EmbeddingFunction):
         if isinstance(query, str):
             return [self.generate_text_embeddings(query)]
         else:
-            PIL = attempt_import_or_raise("PIL", "pillow")
-            if isinstance(query, PIL.Image.Image):
+            PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
+            if isinstance(query, PIL_Image.Image):
                 return [self.generate_image_embedding(query)]
             else:
                 raise TypeError("SigLIP supports str or PIL Image as query")
@@ -127,21 +127,21 @@ class SigLipEmbeddings(EmbeddingFunction):
         return image_features.cpu().detach().numpy().squeeze()

     def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
-        PIL = attempt_import_or_raise("PIL", "pillow")
-        if isinstance(image, PIL.Image.Image):
+        PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
+        if isinstance(image, PIL_Image.Image):
             return image.convert("RGB") if image.mode != "RGB" else image
         elif isinstance(image, bytes):
-            return PIL.Image.open(io.BytesIO(image)).convert("RGB")
+            return PIL_Image.open(io.BytesIO(image)).convert("RGB")
         elif isinstance(image, str):
             parsed = urlparse.urlparse(image)
             if parsed.scheme == "file":
-                return PIL.Image.open(parsed.path).convert("RGB")
+                return PIL_Image.open(parsed.path).convert("RGB")
             elif parsed.scheme == "":
                 path = image if os.name == "nt" else parsed.path
-                return PIL.Image.open(path).convert("RGB")
+                return PIL_Image.open(path).convert("RGB")
             elif parsed.scheme.startswith("http"):
                 image_bytes = url_retrieve(image)
-                return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
+                return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
             else:
                 raise NotImplementedError("Only local and http(s) urls are supported")
         else:

View File

@@ -21,6 +21,9 @@ if TYPE_CHECKING:
 # Token limits for different VoyageAI models
 VOYAGE_TOTAL_TOKEN_LIMITS = {
+    "voyage-4": 320_000,
+    "voyage-4-lite": 1_000_000,
+    "voyage-4-large": 120_000,
     "voyage-context-3": 32_000,
     "voyage-3.5-lite": 1_000_000,
     "voyage-3.5": 320_000,
@@ -61,7 +64,7 @@ def is_video_path(path: Path) -> bool:
 def transform_input(input_data: Union[str, bytes, Path]):
-    PIL = attempt_import_or_raise("PIL", "pillow")
+    PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
     if isinstance(input_data, str):
         if is_valid_url(input_data):
             if is_video_url(input_data):
@@ -70,7 +73,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
                 content = {"type": "image_url", "image_url": input_data}
         else:
             content = {"type": "text", "text": input_data}
-    elif isinstance(input_data, PIL.Image.Image):
+    elif isinstance(input_data, PIL_Image.Image):
         buffered = BytesIO()
         input_data.save(buffered, format="JPEG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -79,7 +82,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
             "image_base64": "data:image/jpeg;base64," + img_str,
         }
     elif isinstance(input_data, bytes):
-        img = PIL.Image.open(BytesIO(input_data))
+        img = PIL_Image.open(BytesIO(input_data))
         buffered = BytesIO()
         img.save(buffered, format="JPEG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -98,7 +101,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
                 "video_base64": video_str,
             }
         else:
-            img = PIL.Image.open(input_data)
+            img = PIL_Image.open(input_data)
             buffered = BytesIO()
             img.save(buffered, format="JPEG")
             img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -116,8 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
     """
    Sanitize the input to the embedding function.
     """
-    PIL = attempt_import_or_raise("PIL", "pillow")
-    if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
+    PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
+    if isinstance(inputs, (str, bytes, Path, PIL_Image.Image)):
         inputs = [inputs]
     elif isinstance(inputs, list):
         pass # Already a list, use as-is
@@ -130,7 +133,7 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
             f"Input type {type(inputs)} not allowed with multimodal model."
         )
-    if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
+    if not all(isinstance(x, (str, bytes, Path, PIL_Image.Image)) for x in inputs):
         raise ValueError("Each input should be either str, bytes, Path or Image.")
     return [transform_input(i) for i in inputs]
@@ -167,6 +170,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
     name: str
         The name of the model to use. List of acceptable models:
+        * voyage-4 (1024 dims, general-purpose and multilingual retrieval)
+        * voyage-4-lite (1024 dims, optimized for latency and cost)
+        * voyage-4-large (1024 dims, best retrieval quality)
         * voyage-context-3
         * voyage-3.5
         * voyage-3.5-lite
@@ -215,6 +221,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
     _FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
     _VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
     text_embedding_models: list = [
+        "voyage-4",
+        "voyage-4-lite",
+        "voyage-4-large",
         "voyage-3.5",
         "voyage-3.5-lite",
         "voyage-3",
@@ -252,6 +261,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
         elif self.name == "voyage-code-2":
             return 1536
         elif self.name in [
+            "voyage-4",
+            "voyage-4-lite",
+            "voyage-4-large",
             "voyage-context-3",
             "voyage-3.5",
             "voyage-3.5-lite",

View File

@@ -275,7 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
         return pa.timestamp("us", tz=tz)
     elif getattr(py_type, "__origin__", None) in (list, tuple):
         child = py_type.__args__[0]
-        return pa.list_(_py_type_to_arrow_type(child, field))
+        return _pydantic_list_child_to_arrow(child, field)
     raise TypeError(
         f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
     )
@@ -298,12 +298,18 @@ else:
 def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
+    def _safe_issubclass(candidate: Any, base: type) -> bool:
+        try:
+            return issubclass(candidate, base)
+        except TypeError:
+            return False
+
     if inspect.isclass(tp):
-        if issubclass(tp, pydantic.BaseModel):
+        if _safe_issubclass(tp, pydantic.BaseModel):
             # Struct
             fields = _pydantic_model_to_fields(tp)
             return pa.struct(fields)
-        if issubclass(tp, FixedSizeListMixin):
+        if _safe_issubclass(tp, FixedSizeListMixin):
             if getattr(tp, "is_multi_vector", lambda: False)():
                 return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
             # For regular Vector
@@ -311,45 +317,67 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
     return _py_type_to_arrow_type(tp, field)

+def _pydantic_list_child_to_arrow(child: Any, field: FieldInfo) -> pa.DataType:
+    unwrapped = _unwrap_optional_annotation(child)
+    if unwrapped is not None:
+        return pa.list_(
+            pa.field("item", _pydantic_type_to_arrow_type(unwrapped, field), True)
+        )
+    return pa.list_(_pydantic_type_to_arrow_type(child, field))
+
+
+def _unwrap_optional_annotation(annotation: Any) -> Any | None:
+    if isinstance(annotation, (_GenericAlias, GenericAlias)):
+        origin = annotation.__origin__
+        args = annotation.__args__
+        if origin == Union:
+            non_none = [arg for arg in args if arg is not type(None)]
+            if len(non_none) == 1 and len(non_none) != len(args):
+                return non_none[0]
+    elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
+        args = annotation.__args__
+        non_none = [arg for arg in args if arg is not type(None)]
+        if len(non_none) == 1 and len(non_none) != len(args):
+            return non_none[0]
+    return None
+
+
 def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
     """Convert a Pydantic FieldInfo to Arrow DataType"""
+    unwrapped = _unwrap_optional_annotation(field.annotation)
+    if unwrapped is not None:
+        return _pydantic_type_to_arrow_type(unwrapped, field)
     if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
         origin = field.annotation.__origin__
         args = field.annotation.__args__
         if origin is list:
             child = args[0]
-            return pa.list_(_py_type_to_arrow_type(child, field))
-        elif origin == Union:
-            if len(args) == 2 and args[1] is type(None):
-                return _pydantic_type_to_arrow_type(args[0], field)
-    elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
-        args = field.annotation.__args__
-        if len(args) == 2:
-            for typ in args:
-                if typ is type(None):
-                    continue
-                return _py_type_to_arrow_type(typ, field)
+            return _pydantic_list_child_to_arrow(child, field)
     return _pydantic_type_to_arrow_type(field.annotation, field)

 def is_nullable(field: FieldInfo) -> bool:
     """Check if a Pydantic FieldInfo is nullable."""
+    if _unwrap_optional_annotation(field.annotation) is not None:
+        return True
     if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
         origin = field.annotation.__origin__
         args = field.annotation.__args__
         if origin == Union:
-            if len(args) == 2 and args[1] is type(None):
+            if any(typ is type(None) for typ in args):
                 return True
     elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
         args = field.annotation.__args__
         for typ in args:
             if typ is type(None):
                 return True
-    elif inspect.isclass(field.annotation) and issubclass(
-        field.annotation, FixedSizeListMixin
-    ):
-        return field.annotation.nullable()
+    elif inspect.isclass(field.annotation):
+        try:
+            if issubclass(field.annotation, FixedSizeListMixin):
+                return field.annotation.nullable()
+        except TypeError:
+            return False
     return False
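The new `_unwrap_optional_annotation` helper above centralizes Optional detection so that `Optional[T]` fields and list children map to nullable Arrow types. Its core logic can be shown in isolation; `unwrap_optional` below is a simplified stand-in built on `typing.get_origin`/`typing.get_args` (the diff's version also handles the `X | None` union syntax from Python 3.10+):

```python
import typing
from typing import Optional, Union


def unwrap_optional(annotation):
    """Return T for an Optional[T] annotation, otherwise None (sketch)."""
    if typing.get_origin(annotation) is Union:
        args = typing.get_args(annotation)
        non_none = [a for a in args if a is not type(None)]
        # Optional[T] is exactly Union[T, None]: one non-None arm plus NoneType.
        if len(non_none) == 1 and len(non_none) != len(args):
            return non_none[0]
    return None
```

For example, `unwrap_optional(Optional[int])` yields `int`, while `Union[int, str]` and plain `int` yield `None`, so only genuinely optional annotations get unwrapped before type conversion.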

View File

@@ -961,27 +961,22 @@ class LanceQueryBuilder(ABC):
         >>> query = [100, 100]
         >>> plan = table.search(query).analyze_plan()
         >>> print(plan)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
-        AnalyzeExec verbose=true, elapsed=..., metrics=...
-          TracedExec, elapsed=..., metrics=...
-            ProjectionExec: elapsed=..., expr=[...],
-                metrics=[output_rows=..., elapsed_compute=..., output_bytes=...]
-              GlobalLimitExec: elapsed=..., skip=0, fetch=10,
-                metrics=[output_rows=..., elapsed_compute=..., output_bytes=...]
-                FilterExec: elapsed=..., _distance@2 IS NOT NULL, metrics=[...]
-                  SortExec: elapsed=..., TopK(fetch=10), expr=[...],
-                      preserve_partitioning=[...],
-                      metrics=[output_rows=..., elapsed_compute=...,
-                      output_bytes=..., row_replacements=...]
-                    KNNVectorDistance: elapsed=..., metric=l2,
-                        metrics=[output_rows=..., elapsed_compute=...,
-                        output_bytes=..., output_batches=...]
-                      LanceRead: elapsed=..., uri=..., projection=[vector],
-                          num_fragments=..., range_before=None, range_after=None,
-                          row_id=true, row_addr=false,
-                          full_filter=--, refine_filter=--,
-                          metrics=[output_rows=..., elapsed_compute=..., output_bytes=...,
-                          fragments_scanned=..., ranges_scanned=1, rows_scanned=1,
-                          bytes_read=..., iops=..., requests=..., task_wait_time=...]
+        AnalyzeExec verbose=true, metrics=[], cumulative_cpu=...
+          TracedExec, metrics=[], cumulative_cpu=...
+            ProjectionExec: expr=[...], metrics=[...], cumulative_cpu=...
+              GlobalLimitExec: skip=0, fetch=10, metrics=[...], cumulative_cpu=...
+                FilterExec: _distance@2 IS NOT NULL,
+                    metrics=[output_rows=..., elapsed_compute=...], cumulative_cpu=...
+                  SortExec: TopK(fetch=10), expr=[...],
+                      preserve_partitioning=[...],
+                      metrics=[output_rows=..., elapsed_compute=..., row_replacements=...],
+                      cumulative_cpu=...
+                    KNNVectorDistance: metric=l2,
+                        metrics=[output_rows=..., elapsed_compute=..., output_batches=...],
+                        cumulative_cpu=...
+                      LanceRead: uri=..., projection=[vector], ...
+                          metrics=[output_rows=..., elapsed_compute=...,
+                          bytes_read=..., iops=..., requests=...], cumulative_cpu=...

         Returns
         -------

View File

@@ -2,12 +2,27 @@
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors
 from datetime import timedelta
 from lancedb.db import AsyncConnection, DBConnection
 import lancedb
 import pytest
 import pytest_asyncio

+def pandas_string_type():
+    """Return the PyArrow string type that pandas uses for string columns.
+
+    pandas 3.0+ uses large_string for string columns, pandas 2.x uses string.
+    """
+    import pandas as pd
+    import pyarrow as pa
+
+    version = tuple(int(x) for x in pd.__version__.split(".")[:2])
+    if version >= (3, 0):
+        return pa.large_utf8()
+    return pa.utf8()
+
 # Use an in-memory database for most tests.
 @pytest.fixture
 def mem_db() -> DBConnection:
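The `pandas_string_type` fixture helper keys the expected Arrow string type off the installed pandas major version. The version gate itself is simple enough to demonstrate without pandas or pyarrow installed; `arrow_string_type_name` below is an illustrative stand-in that returns type names instead of `pyarrow` type objects:

```python
def arrow_string_type_name(pandas_version: str) -> str:
    # pandas 3.0+ materializes string columns as Arrow large_string;
    # pandas 2.x materializes them as Arrow string.
    major_minor = tuple(int(x) for x in pandas_version.split(".")[:2])
    return "large_string" if major_minor >= (3, 0) else "string"
```

Parsing only the first two version components keeps the check robust against patch and pre-release suffixes in the third component.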

View File

@@ -268,6 +268,8 @@ async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConne
def test_create_exist_ok(tmp_db: lancedb.DBConnection): def test_create_exist_ok(tmp_db: lancedb.DBConnection):
from conftest import pandas_string_type
data = pd.DataFrame( data = pd.DataFrame(
{ {
"vector": [[3.1, 4.1], [5.9, 26.5]], "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -286,10 +288,11 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
assert tbl.schema == tbl2.schema assert tbl.schema == tbl2.schema
assert len(tbl) == len(tbl2) assert len(tbl) == len(tbl2)
# pandas 3.0+ uses large_string, pandas 2.x uses string
schema = pa.schema( schema = pa.schema(
[ [
pa.field("vector", pa.list_(pa.float32(), list_size=2)), pa.field("vector", pa.list_(pa.float32(), list_size=2)),
-            pa.field("item", pa.utf8()),
+            pa.field("item", pandas_string_type()),
             pa.field("price", pa.float64()),
         ]
     )
@@ -299,7 +302,7 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
     bad_schema = pa.schema(
         [
             pa.field("vector", pa.list_(pa.float32(), list_size=2)),
-            pa.field("item", pa.utf8()),
+            pa.field("item", pandas_string_type()),
             pa.field("price", pa.float64()),
             pa.field("extra", pa.float32()),
         ]
@@ -365,6 +368,8 @@ async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
 @pytest.mark.asyncio
 async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
+    from conftest import pandas_string_type
+
     data = pd.DataFrame(
         {
             "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -382,10 +387,11 @@ async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
     assert tbl.name == tbl2.name
     assert await tbl.schema() == await tbl2.schema()
+    # pandas 3.0+ uses large_string, pandas 2.x uses string
     schema = pa.schema(
         [
             pa.field("vector", pa.list_(pa.float32(), list_size=2)),
-            pa.field("item", pa.utf8()),
+            pa.field("item", pandas_string_type()),
             pa.field("price", pa.float64()),
         ]
     )
@@ -595,6 +601,8 @@ def test_open_table_sync(tmp_db: lancedb.DBConnection):
 @pytest.mark.asyncio
 async def test_open_table(tmp_path):
+    from conftest import pandas_string_type
+
     db = await lancedb.connect_async(tmp_path)
     data = pd.DataFrame(
         {
@@ -614,10 +622,11 @@ async def test_open_table(tmp_path):
         )
         is not None
     )
+    # pandas 3.0+ uses large_string, pandas 2.x uses string
    assert await tbl.schema() == pa.schema(
         {
             "vector": pa.list_(pa.float32(), list_size=2),
-            "item": pa.utf8(),
+            "item": pandas_string_type(),
             "price": pa.float64(),
         }
     )
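The `pandas_string_type()` helper referenced throughout these tests lives in conftest and is not shown in this diff. Its version gate can be sketched as follows; this is a hypothetical illustration (the real helper presumably returns a `pyarrow` type, while this sketch returns the type's name so it stays dependency-free):

```python
def pandas_string_arrow_type(pandas_version: str) -> str:
    """Name of the Arrow string type pandas produces for str columns.

    pandas 3.0+ defaults to Arrow-backed strings (large_string);
    earlier versions map to plain string.
    """
    major = int(pandas_version.split(".")[0])
    return "large_string" if major >= 3 else "string"
```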


@@ -517,19 +517,36 @@ def test_ollama_embedding(tmp_path):
 @pytest.mark.skipif(
     os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
 )
-def test_voyageai_embedding_function():
-    voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
+@pytest.mark.parametrize(
+    "model_name,expected_dims",
+    [
+        ("voyage-3", 1024),
+        ("voyage-4", 1024),
+        ("voyage-4-lite", 1024),
+        ("voyage-4-large", 1024),
+    ],
+)
+def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
+    """Integration test for VoyageAI text embedding models with real API calls."""
+    voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)

     class TextModel(LanceModel):
         text: str = voyageai.SourceField()
         vector: Vector(voyageai.ndims()) = voyageai.VectorField()

     df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
-    db = lancedb.connect("~/lancedb")
+    db = lancedb.connect(tmp_path)
     tbl = db.create_table("test", schema=TextModel, mode="overwrite")
     tbl.add(df)
     assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
+    assert voyageai.ndims() == expected_dims, (
+        f"{model_name} should have {expected_dims} dimensions"
+    )
+
+    # Test search functionality
+    result = tbl.search("hello").limit(1).to_pandas()
+    assert result["text"][0] == "hello world"


 @pytest.mark.slow


@@ -26,6 +26,8 @@ import pytest
 from lance_namespace import (
     CreateEmptyTableRequest,
     CreateEmptyTableResponse,
+    DeclareTableRequest,
+    DeclareTableResponse,
     DescribeTableRequest,
     DescribeTableResponse,
     LanceNamespace,
@@ -160,6 +162,19 @@ class TrackingNamespace(LanceNamespace):
         return modified

+    def declare_table(self, request: DeclareTableRequest) -> DeclareTableResponse:
+        """Track declare_table calls and inject rotating credentials."""
+        with self.lock:
+            self.create_call_count += 1
+            count = self.create_call_count
+
+        response = self.inner.declare_table(request)
+        response.storage_options = self._modify_storage_options(
+            response.storage_options, count
+        )
+        return response
+
     def create_empty_table(
         self, request: CreateEmptyTableRequest
     ) -> CreateEmptyTableResponse:


@@ -438,11 +438,15 @@ def test_filter_with_splits(mem_db):
     row_count = permutation_tbl.count_rows()
     assert row_count == 67

-    data = permutation_tbl.search(None).to_arrow().to_pydict()
+    # Verify the permutation table only contains row_id and split_id
+    assert set(permutation_tbl.schema.names) == {"row_id", "split_id"}
+
+    row_ids = permutation_tbl.search(None).to_arrow().to_pydict()["row_id"]
+    data = tbl.take_row_ids(row_ids).to_arrow().to_pydict()
     categories = data["category"]
     # All categories should be A or B
-    assert all(cat in ["A", "B"] for cat in categories)
+    assert all(cat in ("A", "B") for cat in categories)


 def test_filter_with_shuffle(mem_db):


@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors

 import json
-import sys
 from datetime import date, datetime
 from typing import List, Optional, Tuple
@@ -20,10 +19,6 @@ from pydantic import BaseModel
 from pydantic import Field


-@pytest.mark.skipif(
-    sys.version_info < (3, 9),
-    reason="using native type alias requires python3.9 or higher",
-)
 def test_pydantic_to_arrow():
     class StructModel(pydantic.BaseModel):
         a: str
@@ -83,10 +78,6 @@ def test_pydantic_to_arrow():
     assert schema == expect_schema


-@pytest.mark.skipif(
-    sys.version_info < (3, 10),
-    reason="using | type syntax requires python3.10 or higher",
-)
 def test_optional_types_py310():
     class TestModel(pydantic.BaseModel):
         a: str | None
@@ -105,10 +96,233 @@ def test_optional_types_py310():
     assert schema == expect_schema


-@pytest.mark.skipif(
-    sys.version_info > (3, 8),
-    reason="using native type alias requires python3.9 or higher",
-)
+def test_optional_structs():
+    class SplitInfo(pydantic.BaseModel):
+        start_frame: int
+        end_frame: int
+
+    class TestModel(pydantic.BaseModel):
+        id: str
+        split: SplitInfo | None = None
+
+    schema = pydantic_to_schema(TestModel)
+    expect_schema = pa.schema(
+        [
+            pa.field("id", pa.utf8(), False),
+            pa.field(
+                "split",
+                pa.struct(
+                    [
+                        pa.field("start_frame", pa.int64(), False),
+                        pa.field("end_frame", pa.int64(), False),
+                    ]
+                ),
+                True,
+            ),
+        ]
+    )
+    assert schema == expect_schema
+
+
+def test_optional_struct_list_py310():
+    class SplitInfo(pydantic.BaseModel):
+        start_frame: int
+        end_frame: int
+
+    class TestModel(pydantic.BaseModel):
+        id: str
+        splits: list[SplitInfo] | None = None
+
+    schema = pydantic_to_schema(TestModel)
+    expect_schema = pa.schema(
+        [
+            pa.field("id", pa.utf8(), False),
+            pa.field(
+                "splits",
+                pa.list_(
+                    pa.struct(
+                        [
+                            pa.field("start_frame", pa.int64(), False),
+                            pa.field("end_frame", pa.int64(), False),
+                        ]
+                    )
+                ),
+                True,
+            ),
+        ]
+    )
+    assert schema == expect_schema
+
+
+def test_nested_struct_list():
+    class SplitInfo(pydantic.BaseModel):
+        start_frame: int
+        end_frame: int
+
+    class TestModel(pydantic.BaseModel):
+        id: str
+        splits: list[SplitInfo]
+
+    schema = pydantic_to_schema(TestModel)
+    expect_schema = pa.schema(
+        [
+            pa.field("id", pa.utf8(), False),
+            pa.field(
+                "splits",
+                pa.list_(
+                    pa.struct(
+                        [
+                            pa.field("start_frame", pa.int64(), False),
+                            pa.field("end_frame", pa.int64(), False),
+                        ]
+                    )
+                ),
+                False,
+            ),
+        ]
+    )
+    assert schema == expect_schema
+
+
+def test_nested_struct_list_optional():
+    class SplitInfo(pydantic.BaseModel):
+        start_frame: int
+        end_frame: int
+
+    class TestModel(pydantic.BaseModel):
+        id: str
+        splits: Optional[list[SplitInfo]] = None
+
+    schema = pydantic_to_schema(TestModel)
+    expect_schema = pa.schema(
+        [
+            pa.field("id", pa.utf8(), False),
+            pa.field(
+                "splits",
+                pa.list_(
+                    pa.struct(
+                        [
+                            pa.field("start_frame", pa.int64(), False),
+                            pa.field("end_frame", pa.int64(), False),
+                        ]
+                    )
+                ),
+                True,
+            ),
+        ]
+    )
+    assert schema == expect_schema
+
+
+def test_nested_struct_list_optional_items():
+    class SplitInfo(pydantic.BaseModel):
+        start_frame: int
+        end_frame: int
+
+    class TestModel(pydantic.BaseModel):
+        id: str
+        splits: list[Optional[SplitInfo]]
+
+    schema = pydantic_to_schema(TestModel)
+    expect_schema = pa.schema(
+        [
+            pa.field("id", pa.utf8(), False),
+            pa.field(
+                "splits",
+                pa.list_(
+                    pa.field(
+                        "item",
+                        pa.struct(
+                            [
+                                pa.field("start_frame", pa.int64(), False),
+                                pa.field("end_frame", pa.int64(), False),
+                            ]
+                        ),
+                        True,
+                    )
+                ),
+                False,
+            ),
+        ]
+    )
+    assert schema == expect_schema
+
+
+def test_nested_struct_list_optional_container_and_items():
+    class SplitInfo(pydantic.BaseModel):
+        start_frame: int
+        end_frame: int
+
+    class TestModel(pydantic.BaseModel):
+        id: str
+        splits: Optional[list[Optional[SplitInfo]]] = None
+
+    schema = pydantic_to_schema(TestModel)
+    expect_schema = pa.schema(
+        [
+            pa.field("id", pa.utf8(), False),
+            pa.field(
+                "splits",
+                pa.list_(
+                    pa.field(
+                        "item",
+                        pa.struct(
+                            [
+                                pa.field("start_frame", pa.int64(), False),
+                                pa.field("end_frame", pa.int64(), False),
+                            ]
+                        ),
+                        True,
+                    )
+                ),
+                True,
+            ),
+        ]
+    )
+    assert schema == expect_schema
+
+
+def test_nested_struct_list_optional_items_pep604():
+    class SplitInfo(pydantic.BaseModel):
+        start_frame: int
+        end_frame: int
+
+    class TestModel(pydantic.BaseModel):
+        id: str
+        splits: list[SplitInfo | None]
+
+    schema = pydantic_to_schema(TestModel)
+    expect_schema = pa.schema(
+        [
+            pa.field("id", pa.utf8(), False),
+            pa.field(
+                "splits",
+                pa.list_(
+                    pa.field(
+                        "item",
+                        pa.struct(
+                            [
+                                pa.field("start_frame", pa.int64(), False),
+                                pa.field("end_frame", pa.int64(), False),
+                            ]
+                        ),
+                        True,
+                    )
+                ),
+                False,
+            ),
+        ]
+    )
+    assert schema == expect_schema
+
+
 def test_pydantic_to_arrow_py38():
     class StructModel(pydantic.BaseModel):
         a: str
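The nullability flags asserted above (the trailing `True`/`False` on each struct and list field) follow from whether the annotation admits `None`. A minimal, stdlib-only sketch of that rule — the function name here is illustrative, not LanceDB's internal API:

```python
import types
from typing import Optional, Union, get_args, get_origin


def is_nullable(annotation) -> bool:
    """True when an annotation admits None: Optional[X], Union[X, None], X | None."""
    origin = get_origin(annotation)
    # PEP 604 unions (X | None) have origin types.UnionType on Python 3.10+
    if origin is not Union and origin is not getattr(types, "UnionType", None):
        return False
    return type(None) in get_args(annotation)
```

Note that the container and its items are checked independently, which is why `list[Optional[SplitInfo]]` yields a non-nullable list of nullable items.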


@@ -8,7 +8,7 @@ import http.server
 import json
 import threading
 import time
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 import uuid

 from packaging.version import Version
@@ -1203,3 +1203,22 @@ async def test_header_provider_overrides_static_headers():
         extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
     ) as db:
         await db.table_names()
+
+
+@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
+def test_background_loop_cancellation(exception):
+    """Test that BackgroundEventLoop.run() cancels the future on interrupt."""
+    from lancedb.background_loop import BackgroundEventLoop
+
+    mock_future = MagicMock()
+    mock_future.result.side_effect = exception()
+
+    with (
+        patch.object(BackgroundEventLoop, "__init__", return_value=None),
+        patch("asyncio.run_coroutine_threadsafe", return_value=mock_future),
+    ):
+        loop = BackgroundEventLoop()
+        loop.loop = MagicMock()
+        with pytest.raises(exception):
+            loop.run(None)
+
+    mock_future.cancel.assert_called_once()
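The test above pins down the behavior being fixed: on interrupt, `run()` must cancel the in-flight future before re-raising, so the remote query coroutine doesn't keep running in the background thread. A minimal sketch of that pattern — hypothetical, the real `lancedb.background_loop.BackgroundEventLoop` may differ:

```python
import asyncio
import threading


class BackgroundEventLoop:
    """Event loop on a daemon thread; sync callers submit coroutines to it."""

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=self.loop.run_forever, daemon=True).start()

    def run(self, coro):
        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        try:
            return future.result()
        except (KeyboardInterrupt, SystemExit, GeneratorExit):
            # Without this, a Ctrl-C in the sync API would leave the
            # coroutine running (and its request in flight) in the background.
            future.cancel()
            raise
```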


@@ -528,12 +528,19 @@ def test_sanitize_data(
         else:
             expected_schema = schema
     else:
+        from conftest import pandas_string_type
+
+        # polars uses large_string, pandas 3.0+ uses large_string, others use string
+        if isinstance(data, pl.DataFrame):
+            text_type = pa.large_utf8()
+        elif isinstance(data, pd.DataFrame):
+            text_type = pandas_string_type()
+        else:
+            text_type = pa.string()
         expected_schema = pa.schema(
             {
                 "id": pa.int64(),
-                "text": pa.large_utf8()
-                if isinstance(data, pl.DataFrame)
-                else pa.string(),
+                "text": text_type,
                 "vector": pa.list_(pa.float32(), 10),
             }
         )


@@ -0,0 +1,108 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+"""Unit tests for VoyageAI embedding function.
+
+These tests verify model registration and configuration without requiring API calls.
+"""
+
+import pytest
+from unittest.mock import MagicMock, patch
+
+from lancedb.embeddings import get_registry
+
+
+@pytest.fixture(autouse=True)
+def reset_voyageai_client():
+    """Reset VoyageAI client before and after each test to avoid state pollution."""
+    from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
+
+    VoyageAIEmbeddingFunction.client = None
+    yield
+    VoyageAIEmbeddingFunction.client = None
+
+
+class TestVoyageAIModelRegistration:
+    """Tests for VoyageAI model registration and configuration."""
+
+    @pytest.fixture
+    def mock_voyageai_client(self):
+        """Mock VoyageAI client to avoid API calls."""
+        with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
+            with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
+                mock_client = MagicMock()
+                mock_voyageai = MagicMock()
+                mock_voyageai.Client.return_value = mock_client
+                mock.return_value = mock_voyageai
+                yield mock_client
+
+    def test_voyageai_registered(self):
+        """Test that VoyageAI is registered in the embedding function registry."""
+        registry = get_registry()
+        assert registry.get("voyageai") is not None
+
+    @pytest.mark.parametrize(
+        "model_name,expected_dims",
+        [
+            # Voyage-4 series (all 1024 dims)
+            ("voyage-4", 1024),
+            ("voyage-4-lite", 1024),
+            ("voyage-4-large", 1024),
+            # Voyage-3 series
+            ("voyage-3", 1024),
+            ("voyage-3-lite", 512),
+            # Domain-specific models
+            ("voyage-finance-2", 1024),
+            ("voyage-multilingual-2", 1024),
+            ("voyage-law-2", 1024),
+            ("voyage-code-2", 1536),
+            # Multimodal
+            ("voyage-multimodal-3", 1024),
+        ],
+    )
+    def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
+        """Test that each model returns the correct dimensions."""
+        registry = get_registry()
+        func = registry.get("voyageai").create(name=model_name)
+        assert func.ndims() == expected_dims, (
+            f"Model {model_name} should have {expected_dims} dimensions"
+        )
+
+    def test_unsupported_model_raises_error(self, mock_voyageai_client):
+        """Test that unsupported models raise ValueError."""
+        registry = get_registry()
+        func = registry.get("voyageai").create(name="unsupported-model")
+        with pytest.raises(ValueError, match="not supported"):
+            func.ndims()
+
+    @pytest.mark.parametrize(
+        "model_name",
+        [
+            "voyage-4",
+            "voyage-4-lite",
+            "voyage-4-large",
+        ],
+    )
+    def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
+        """Test that voyage-4 models are classified as text models (not multimodal)."""
+        registry = get_registry()
+        func = registry.get("voyageai").create(name=model_name)
+        assert not func._is_multimodal_model(model_name), (
+            f"{model_name} should be a text model, not multimodal"
+        )
+
+    def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
+        """Test that voyage-4 models are in the text_embedding_models list."""
+        registry = get_registry()
+        func = registry.get("voyageai").create(name="voyage-4")
+        assert "voyage-4" in func.text_embedding_models
+        assert "voyage-4-lite" in func.text_embedding_models
+        assert "voyage-4-large" in func.text_embedding_models
+
+    def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
+        """Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
+        registry = get_registry()
+        func = registry.get("voyageai").create(name="voyage-4")
+        assert "voyage-4" not in func.multimodal_embedding_models
+        assert "voyage-4-lite" not in func.multimodal_embedding_models
+        assert "voyage-4-large" not in func.multimodal_embedding_models


@@ -10,7 +10,8 @@ use arrow::{
 use futures::stream::StreamExt;
 use lancedb::arrow::SendableRecordBatchStream;
 use pyo3::{
-    exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python,
+    exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult,
+    Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;
@@ -35,11 +36,8 @@ impl RecordBatchStream {
 #[pymethods]
 impl RecordBatchStream {
     #[getter]
-    pub fn schema(&self, py: Python) -> PyResult<Py<PyAny>> {
-        (*self.schema)
-            .clone()
-            .into_pyarrow(py)
-            .map(|obj| obj.unbind())
+    pub fn schema(&self, py: Python) -> PyResult<PyObject> {
+        (*self.schema).clone().into_pyarrow(py)
     }

     pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
@@ -55,12 +53,7 @@ impl RecordBatchStream {
                 .next()
                 .await
                 .ok_or_else(|| PyStopAsyncIteration::new_err(""))?;
-            #[allow(deprecated)]
-            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
-                let bound = inner_next.infer_error()?.to_pyarrow(py)?;
-                Ok(bound.unbind())
-            })?;
-            Ok(py_obj)
+            Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py))
         })
     }
 }


@@ -12,7 +12,7 @@ use pyo3::{
     exceptions::{PyRuntimeError, PyValueError},
     pyclass, pyfunction, pymethods,
     types::{PyDict, PyDictMethods},
-    Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
+    Bound, FromPyObject, Py, PyAny, PyObject, PyRef, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;
@@ -114,7 +114,7 @@ impl Connection {
         data: Bound<'_, PyAny>,
         namespace: Vec<String>,
         storage_options: Option<HashMap<String, String>>,
-        storage_options_provider: Option<Py<PyAny>>,
+        storage_options_provider: Option<PyObject>,
         location: Option<String>,
     ) -> PyResult<Bound<'a, PyAny>> {
         let inner = self_.get_inner()?.clone();
@@ -152,7 +152,7 @@ impl Connection {
         schema: Bound<'_, PyAny>,
         namespace: Vec<String>,
         storage_options: Option<HashMap<String, String>>,
-        storage_options_provider: Option<Py<PyAny>>,
+        storage_options_provider: Option<PyObject>,
         location: Option<String>,
     ) -> PyResult<Bound<'a, PyAny>> {
         let inner = self_.get_inner()?.clone();
@@ -187,7 +187,7 @@ impl Connection {
         name: String,
         namespace: Vec<String>,
         storage_options: Option<HashMap<String, String>>,
-        storage_options_provider: Option<Py<PyAny>>,
+        storage_options_provider: Option<PyObject>,
         index_cache_size: Option<u32>,
         location: Option<String>,
     ) -> PyResult<Bound<'_, PyAny>> {
@@ -307,7 +307,6 @@ impl Connection {
             ..Default::default()
         };
         let response = inner.list_namespaces(request).await.infer_error()?;
-        #[allow(deprecated)]
         Python::with_gil(|py| -> PyResult<Py<PyDict>> {
             let dict = PyDict::new(py);
             dict.set_item("namespaces", response.namespaces)?;
@@ -328,7 +327,8 @@ impl Connection {
         let py = self_.py();
         future_into_py(py, async move {
             use lance_namespace::models::CreateNamespaceRequest;
-            let mode_enum = mode.and_then(|m| match m.to_lowercase().as_str() {
+            // Mode is now a string field
+            let mode_str = mode.and_then(|m| match m.to_lowercase().as_str() {
                 "create" => Some("Create".to_string()),
                 "exist_ok" => Some("ExistOk".to_string()),
                 "overwrite" => Some("Overwrite".to_string()),
@@ -340,12 +340,11 @@ impl Connection {
                 } else {
                     Some(namespace)
                 },
-                mode: mode_enum,
+                mode: mode_str,
                 properties,
                 ..Default::default()
             };
             let response = inner.create_namespace(request).await.infer_error()?;
-            #[allow(deprecated)]
             Python::with_gil(|py| -> PyResult<Py<PyDict>> {
                 let dict = PyDict::new(py);
                 dict.set_item("properties", response.properties)?;
@@ -365,12 +364,13 @@ impl Connection {
         let py = self_.py();
         future_into_py(py, async move {
             use lance_namespace::models::DropNamespaceRequest;
-            let mode_enum = mode.and_then(|m| match m.to_uppercase().as_str() {
+            // Mode and Behavior are now string fields
+            let mode_str = mode.and_then(|m| match m.to_uppercase().as_str() {
                 "SKIP" => Some("Skip".to_string()),
                 "FAIL" => Some("Fail".to_string()),
                 _ => None,
             });
-            let behavior_enum = behavior.and_then(|b| match b.to_uppercase().as_str() {
+            let behavior_str = behavior.and_then(|b| match b.to_uppercase().as_str() {
                 "RESTRICT" => Some("Restrict".to_string()),
                 "CASCADE" => Some("Cascade".to_string()),
                 _ => None,
@@ -381,12 +381,11 @@ impl Connection {
                 } else {
                     Some(namespace)
                 },
-                mode: mode_enum,
-                behavior: behavior_enum,
+                mode: mode_str,
+                behavior: behavior_str,
                 ..Default::default()
             };
             let response = inner.drop_namespace(request).await.infer_error()?;
-            #[allow(deprecated)]
             Python::with_gil(|py| -> PyResult<Py<PyDict>> {
                 let dict = PyDict::new(py);
                 dict.set_item("properties", response.properties)?;
@@ -414,7 +413,6 @@ impl Connection {
             ..Default::default()
         };
         let response = inner.describe_namespace(request).await.infer_error()?;
-        #[allow(deprecated)]
         Python::with_gil(|py| -> PyResult<Py<PyDict>> {
             let dict = PyDict::new(py);
             dict.set_item("properties", response.properties)?;
@@ -445,7 +443,6 @@ impl Connection {
             ..Default::default()
         };
         let response = inner.list_tables(request).await.infer_error()?;
-        #[allow(deprecated)]
         Python::with_gil(|py| -> PyResult<Py<PyDict>> {
             let dict = PyDict::new(py);
             dict.set_item("tables", response.tables)?;


@@ -40,34 +40,31 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
                 request_id,
                 source,
                 status_code,
-            } => {
-                #[allow(deprecated)]
-                Python::with_gil(|py| {
-                    let message = err.to_string();
-                    let http_err_cls = py
-                        .import(intern!(py, "lancedb.remote.errors"))?
-                        .getattr(intern!(py, "HttpError"))?;
-                    let err = http_err_cls.call1((
-                        message,
-                        request_id,
-                        status_code.map(|s| s.as_u16()),
-                    ))?;
-
-                    if let Some(cause) = source.source() {
-                        // The HTTP error already includes the first cause. But
-                        // we can add the rest of the chain if there is any more.
-                        let cause_err = http_from_rust_error(
-                            py,
-                            cause,
-                            request_id,
-                            status_code.map(|s| s.as_u16()),
-                        )?;
-                        err.setattr(intern!(py, "__cause__"), cause_err)?;
-                    }
-
-                    Err(PyErr::from_value(err))
-                })
-            }
+            } => Python::with_gil(|py| {
+                let message = err.to_string();
+                let http_err_cls = py
+                    .import(intern!(py, "lancedb.remote.errors"))?
+                    .getattr(intern!(py, "HttpError"))?;
+                let err = http_err_cls.call1((
+                    message,
+                    request_id,
+                    status_code.map(|s| s.as_u16()),
+                ))?;
+
+                if let Some(cause) = source.source() {
+                    // The HTTP error already includes the first cause. But
+                    // we can add the rest of the chain if there is any more.
+                    let cause_err = http_from_rust_error(
+                        py,
+                        cause,
+                        request_id,
+                        status_code.map(|s| s.as_u16()),
+                    )?;
+                    err.setattr(intern!(py, "__cause__"), cause_err)?;
+                }
+
+                Err(PyErr::from_value(err))
+            }),
             LanceError::Retry {
                 request_id,
                 request_failures,
@@ -78,37 +75,33 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
                 max_read_failures,
                 source,
                 status_code,
-            } =>
-            {
-                #[allow(deprecated)]
-                Python::with_gil(|py| {
-                    let cause_err = http_from_rust_error(
-                        py,
-                        source.as_ref(),
-                        request_id,
-                        status_code.map(|s| s.as_u16()),
-                    )?;
-                    let message = err.to_string();
-                    let retry_error_cls = py
-                        .import(intern!(py, "lancedb.remote.errors"))?
-                        .getattr("RetryError")?;
-                    let err = retry_error_cls.call1((
-                        message,
-                        request_id,
-                        *request_failures,
-                        *connect_failures,
-                        *read_failures,
-                        *max_request_failures,
-                        *max_connect_failures,
-                        *max_read_failures,
-                        status_code.map(|s| s.as_u16()),
-                    ))?;
-                    err.setattr(intern!(py, "__cause__"), cause_err)?;
-                    Err(PyErr::from_value(err))
-                })
-            }
+            } => Python::with_gil(|py| {
+                let cause_err = http_from_rust_error(
+                    py,
+                    source.as_ref(),
+                    request_id,
+                    status_code.map(|s| s.as_u16()),
+                )?;
+                let message = err.to_string();
+                let retry_error_cls = py
+                    .import(intern!(py, "lancedb.remote.errors"))?
+                    .getattr("RetryError")?;
+                let err = retry_error_cls.call1((
+                    message,
+                    request_id,
+                    *request_failures,
+                    *connect_failures,
+                    *read_failures,
+                    *max_request_failures,
+                    *max_connect_failures,
+                    *max_read_failures,
+                    status_code.map(|s| s.as_u16()),
+                ))?;
+                err.setattr(intern!(py, "__cause__"), cause_err)?;
+                Err(PyErr::from_value(err))
+            }),
             _ => self.runtime_error(),
         },
     }


@@ -12,7 +12,6 @@ pub struct PyHeaderProvider {
 impl Clone for PyHeaderProvider {
     fn clone(&self) -> Self {
-        #[allow(deprecated)]
         Python::with_gil(|py| Self {
             provider: self.provider.clone_ref(py),
         })
@@ -26,7 +25,6 @@ impl PyHeaderProvider {
     /// Get headers from the Python provider (internal implementation)
     fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> {
-        #[allow(deprecated)]
         Python::with_gil(|py| {
             // Call the get_headers method
             let result = self.provider.call_method0(py, "get_headers");


@@ -19,7 +19,7 @@ use pyo3::{
     exceptions::PyRuntimeError,
     pyclass, pymethods,
     types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
-    Bound, Py, PyAny, PyRef, PyRefMut, PyResult, Python,
+    Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;
@@ -281,12 +281,7 @@ impl PyPermutationReader {
         let reader = slf.reader.clone();
         future_into_py(slf.py(), async move {
             let schema = reader.output_schema(selection).await.infer_error()?;
-            #[allow(deprecated)]
-            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
-                let bound = schema.to_pyarrow(py)?;
-                Ok(bound.unbind())
-            })?;
-            Ok(py_obj)
+            Python::with_gil(|py| schema.to_pyarrow(py))
         })
     }


@@ -29,7 +29,6 @@ use pyo3::types::PyList;
 use pyo3::types::{PyDict, PyString};
 use pyo3::Bound;
 use pyo3::IntoPyObject;
-use pyo3::Py;
 use pyo3::PyAny;
 use pyo3::PyRef;
 use pyo3::PyResult;
@@ -454,12 +453,7 @@ impl Query {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            #[allow(deprecated)]
-            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
-                let bound = schema.to_pyarrow(py)?;
-                Ok(bound.unbind())
-            })?;
-            Ok(py_obj)
+            Python::with_gil(|py| schema.to_pyarrow(py))
         })
     }
@@ -538,12 +532,7 @@ impl TakeQuery {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            #[allow(deprecated)]
-            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
-                let bound = schema.to_pyarrow(py)?;
-                Ok(bound.unbind())
-            })?;
-            Ok(py_obj)
+            Python::with_gil(|py| schema.to_pyarrow(py))
         })
     }
@@ -638,12 +627,7 @@ impl FTSQuery {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            #[allow(deprecated)]
-            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
-                let bound = schema.to_pyarrow(py)?;
-                Ok(bound.unbind())
-            })?;
-            Ok(py_obj)
+            Python::with_gil(|py| schema.to_pyarrow(py))
         })
     }
@@ -822,12 +806,7 @@ impl VectorQuery {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.output_schema().await.infer_error()?;
-            #[allow(deprecated)]
-            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
-                let bound = schema.to_pyarrow(py)?;
-                Ok(bound.unbind())
-            })?;
-            Ok(py_obj)
+            Python::with_gil(|py| schema.to_pyarrow(py))
        })
     }


@@ -17,12 +17,11 @@ use pyo3::types::PyDict;
 /// Internal wrapper around a Python object implementing StorageOptionsProvider
 pub struct PyStorageOptionsProvider {
     /// The Python object implementing fetch_storage_options()
-    inner: Py<PyAny>,
+    inner: PyObject,
 }

 impl Clone for PyStorageOptionsProvider {
     fn clone(&self) -> Self {
-        #[allow(deprecated)]
         Python::with_gil(|py| Self {
             inner: self.inner.clone_ref(py),
         })
@@ -30,8 +29,7 @@ impl Clone for PyStorageOptionsProvider {
 }

 impl PyStorageOptionsProvider {
-    pub fn new(obj: Py<PyAny>) -> PyResult<Self> {
-        #[allow(deprecated)]
+    pub fn new(obj: PyObject) -> PyResult<Self> {
         Python::with_gil(|py| {
             // Verify the object has a fetch_storage_options method
             if !obj.bind(py).hasattr("fetch_storage_options")? {
@@ -39,9 +37,7 @@ impl PyStorageOptionsProvider {
                     "StorageOptionsProvider must implement fetch_storage_options() method",
                 ));
             }
-            Ok(Self {
-                inner: obj.clone_ref(py),
-            })
+            Ok(Self { inner: obj })
         })
     }
 }
@@ -64,7 +60,6 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
         let py_provider = self.py_provider.clone();
         tokio::task::spawn_blocking(move || {
-            #[allow(deprecated)]
             Python::with_gil(|py| {
                 // Call the Python fetch_storage_options method
                 let result = py_provider
@@ -124,7 +119,6 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
     }

     fn provider_id(&self) -> String {
-        #[allow(deprecated)]
         Python::with_gil(|py| {
             // Call provider_id() method on the Python object
             let obj = self.py_provider.inner.bind(py);
@@ -149,7 +143,7 @@ impl std::fmt::Debug for PyStorageOptionsProviderWrapper {
 /// This is the main entry point for converting Python StorageOptionsProvider objects
 /// to Rust trait objects that can be used by the Lance ecosystem.
 pub fn py_object_to_storage_options_provider(
-    py_obj: Py<PyAny>,
+    py_obj: PyObject,
 ) -> PyResult<Arc<dyn StorageOptionsProvider>> {
     let py_provider = PyStorageOptionsProvider::new(py_obj)?;
     Ok(Arc::new(PyStorageOptionsProviderWrapper::new(py_provider)))

View File

@@ -21,7 +21,7 @@ use pyo3::{
     exceptions::{PyKeyError, PyRuntimeError, PyValueError},
     pyclass, pymethods,
     types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
-    Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
+    Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;
@@ -287,12 +287,7 @@ impl Table {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {
             let schema = inner.schema().await.infer_error()?;
-            #[allow(deprecated)]
-            let py_obj: Py<PyAny> = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
-                let bound = schema.to_pyarrow(py)?;
-                Ok(bound.unbind())
-            })?;
-            Ok(py_obj)
+            Python::with_gil(|py| schema.to_pyarrow(py))
         })
     }
@@ -442,7 +437,6 @@ impl Table {
         future_into_py(self_.py(), async move {
             let stats = inner.index_stats(&index_name).await.infer_error()?;
             if let Some(stats) = stats {
-                #[allow(deprecated)]
                 Python::with_gil(|py| {
                     let dict = PyDict::new(py);
                     dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
@@ -473,7 +467,6 @@ impl Table {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {
             let stats = inner.stats().await.infer_error()?;
-            #[allow(deprecated)]
             Python::with_gil(|py| {
                 let dict = PyDict::new(py);
                 dict.set_item("total_bytes", stats.total_bytes)?;
@@ -528,7 +521,6 @@ impl Table {
         let inner = self_.inner_ref()?.clone();
         future_into_py(self_.py(), async move {
             let versions = inner.list_versions().await.infer_error()?;
-            #[allow(deprecated)]
             let versions_as_dict = Python::with_gil(|py| {
                 versions
                     .iter()
@@ -880,7 +872,6 @@ impl Tags {
             let tags = inner.tags().await.infer_error()?;
             let res = tags.list().await.infer_error()?;
-            #[allow(deprecated)]
             Python::with_gil(|py| {
                 let py_dict = PyDict::new(py);
                 for (key, contents) in res {

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb" name = "lancedb"
version = "0.23.1" version = "0.24.1"
edition.workspace = true edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true

View File

@@ -36,10 +36,42 @@ use crate::remote::{
 };
 use crate::table::{TableDefinition, WriteOptions};
 use crate::Table;
+use lance::io::ObjectStoreParams;
 pub use lance_encoding::version::LanceFileVersion;
 #[cfg(feature = "remote")]
 use lance_io::object_store::StorageOptions;
-use lance_io::object_store::StorageOptionsProvider;
+use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider};
+
+fn merge_storage_options(
+    store_params: &mut ObjectStoreParams,
+    pairs: impl IntoIterator<Item = (String, String)>,
+) {
+    let mut options = store_params.storage_options().cloned().unwrap_or_default();
+    for (key, value) in pairs {
+        options.insert(key, value);
+    }
+    let provider = store_params
+        .storage_options_accessor
+        .as_ref()
+        .and_then(|accessor| accessor.provider().cloned());
+    let accessor = if let Some(provider) = provider {
+        StorageOptionsAccessor::with_initial_and_provider(options, provider)
+    } else {
+        StorageOptionsAccessor::with_static_options(options)
+    };
+    store_params.storage_options_accessor = Some(Arc::new(accessor));
+}
+
+fn set_storage_options_provider(
+    store_params: &mut ObjectStoreParams,
+    provider: Arc<dyn StorageOptionsProvider>,
+) {
+    let accessor = match store_params.storage_options().cloned() {
+        Some(options) => StorageOptionsAccessor::with_initial_and_provider(options, provider),
+        None => StorageOptionsAccessor::with_provider(provider),
+    };
+    store_params.storage_options_accessor = Some(Arc::new(accessor));
+}
 /// A builder for configuring a [`Connection::table_names`] operation
 pub struct TableNamesBuilder {
@@ -219,8 +251,36 @@ impl CreateTableBuilder<false> {
     /// Execute the create table operation
     pub async fn execute(self) -> Result<Table> {
         let parent = self.parent.clone();
-        let table = parent.create_table(self.request).await?;
-        Ok(Table::new(table, parent))
+        let embedding_registry = self.embedding_registry.clone();
+        let request = self.into_request()?;
+        Ok(Table::new_with_embedding_registry(
+            parent.create_table(request).await?,
+            parent,
+            embedding_registry,
+        ))
+    }
+
+    fn into_request(self) -> Result<CreateTableRequest> {
+        if self.embeddings.is_empty() {
+            return Ok(self.request);
+        }
+        let CreateTableData::Empty(table_def) = self.request.data else {
+            unreachable!("CreateTableBuilder<false> should always have Empty data")
+        };
+        let schema = table_def.schema.clone();
+        let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone());
+        let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::<Vec<_>>());
+        let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema);
+        let with_embeddings = WithEmbeddings::new(reader, self.embeddings);
+        let table_definition = with_embeddings.table_definition()?;
+        Ok(CreateTableRequest {
+            data: CreateTableData::Empty(table_definition),
+            ..self.request
+        })
     }
 }
@@ -246,16 +306,14 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// ///
/// See available options at <https://lancedb.com/docs/storage/> /// See available options at <https://lancedb.com/docs/storage/>
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self { pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
let store_options = self let store_params = self
.request .request
.write_options .write_options
.lance_write_params .lance_write_params
.get_or_insert(Default::default()) .get_or_insert(Default::default())
.store_params .store_params
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default()); .get_or_insert(Default::default());
store_options.insert(key.into(), value.into()); merge_storage_options(store_params, [(key.into(), value.into())]);
self self
} }
@@ -269,19 +327,17 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
mut self, mut self,
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>, pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> Self { ) -> Self {
let store_options = self let store_params = self
.request .request
.write_options .write_options
.lance_write_params .lance_write_params
.get_or_insert(Default::default()) .get_or_insert(Default::default())
.store_params .store_params
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default()); .get_or_insert(Default::default());
let updates = pairs
for (key, value) in pairs { .into_iter()
store_options.insert(key.into(), value.into()); .map(|(key, value)| (key.into(), value.into()));
} merge_storage_options(store_params, updates);
self self
} }
@@ -318,23 +374,21 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// This has no effect in LanceDB Cloud. /// This has no effect in LanceDB Cloud.
#[deprecated(since = "0.15.1", note = "Use `database_options` instead")] #[deprecated(since = "0.15.1", note = "Use `database_options` instead")]
pub fn enable_v2_manifest_paths(mut self, use_v2_manifest_paths: bool) -> Self { pub fn enable_v2_manifest_paths(mut self, use_v2_manifest_paths: bool) -> Self {
let storage_options = self let store_params = self
.request .request
.write_options .write_options
.lance_write_params .lance_write_params
.get_or_insert_with(Default::default) .get_or_insert_with(Default::default)
.store_params .store_params
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default); .get_or_insert_with(Default::default);
let value = if use_v2_manifest_paths {
storage_options.insert( "true".to_string()
OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(), } else {
if use_v2_manifest_paths { "false".to_string()
"true".to_string() };
} else { merge_storage_options(
"false".to_string() store_params,
}, [(OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(), value)],
); );
self self
} }
@@ -344,19 +398,19 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// The default is `LanceFileVersion::Stable`. /// The default is `LanceFileVersion::Stable`.
#[deprecated(since = "0.15.1", note = "Use `database_options` instead")] #[deprecated(since = "0.15.1", note = "Use `database_options` instead")]
pub fn data_storage_version(mut self, data_storage_version: LanceFileVersion) -> Self { pub fn data_storage_version(mut self, data_storage_version: LanceFileVersion) -> Self {
let storage_options = self let store_params = self
.request .request
.write_options .write_options
.lance_write_params .lance_write_params
.get_or_insert_with(Default::default) .get_or_insert_with(Default::default)
.store_params .store_params
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default); .get_or_insert_with(Default::default);
merge_storage_options(
storage_options.insert( store_params,
OPT_NEW_TABLE_STORAGE_VERSION.to_string(), [(
data_storage_version.to_string(), OPT_NEW_TABLE_STORAGE_VERSION.to_string(),
data_storage_version.to_string(),
)],
); );
self self
} }
@@ -381,13 +435,14 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
     /// This allows tables to automatically refresh cloud storage credentials
     /// when they expire, enabling long-running operations on remote storage.
     pub fn storage_options_provider(mut self, provider: Arc<dyn StorageOptionsProvider>) -> Self {
-        self.request
+        let store_params = self
+            .request
             .write_options
             .lance_write_params
             .get_or_insert(Default::default())
             .store_params
-            .get_or_insert(Default::default())
-            .storage_options_provider = Some(provider);
+            .get_or_insert(Default::default());
+        set_storage_options_provider(store_params, provider);
         self
     }
 }
@@ -450,15 +505,13 @@ impl OpenTableBuilder {
     ///
     /// See available options at <https://lancedb.com/docs/storage/>
     pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
-        let storage_options = self
+        let store_params = self
             .request
             .lance_read_params
             .get_or_insert(Default::default())
             .store_options
-            .get_or_insert(Default::default())
-            .storage_options
             .get_or_insert(Default::default());
-        storage_options.insert(key.into(), value.into());
+        merge_storage_options(store_params, [(key.into(), value.into())]);
         self
     }
@@ -472,18 +525,16 @@ impl OpenTableBuilder {
         mut self,
         pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
     ) -> Self {
-        let storage_options = self
+        let store_params = self
            .request
            .lance_read_params
            .get_or_insert(Default::default())
            .store_options
-            .get_or_insert(Default::default())
-            .storage_options
            .get_or_insert(Default::default());
-        for (key, value) in pairs {
-            storage_options.insert(key.into(), value.into());
-        }
+        let updates = pairs
+            .into_iter()
+            .map(|(key, value)| (key.into(), value.into()));
+        merge_storage_options(store_params, updates);
         self
     }
@@ -507,12 +558,13 @@ impl OpenTableBuilder {
     /// This allows tables to automatically refresh cloud storage credentials
     /// when they expire, enabling long-running operations on remote storage.
     pub fn storage_options_provider(mut self, provider: Arc<dyn StorageOptionsProvider>) -> Self {
-        self.request
+        let store_params = self
+            .request
             .lance_read_params
             .get_or_insert(Default::default())
             .store_options
-            .get_or_insert(Default::default())
-            .storage_options_provider = Some(provider);
+            .get_or_insert(Default::default());
+        set_storage_options_provider(store_params, provider);
         self
     }
@@ -868,6 +920,10 @@ pub struct ConnectBuilder {
     embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
 }
+
+#[cfg(feature = "remote")]
+const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] =
+    [("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];
+
 impl ConnectBuilder {
     /// Create a new [`ConnectOptions`] with the given database URI.
     pub fn new(uri: &str) -> Self {
@@ -1051,11 +1107,27 @@
         self
     }
+
+    #[cfg(feature = "remote")]
+    fn apply_env_defaults(
+        env_var_to_remote_storage_option: &[(&str, &str)],
+        options: &mut HashMap<String, String>,
+    ) {
+        for (env_key, opt_key) in env_var_to_remote_storage_option {
+            if let Ok(env_value) = std::env::var(env_key) {
+                if !options.contains_key(*opt_key) {
+                    options.insert((*opt_key).to_string(), env_value);
+                }
+            }
+        }
+    }
+
     #[cfg(feature = "remote")]
     fn execute_remote(self) -> Result<Connection> {
         use crate::remote::db::RemoteDatabaseOptions;
-        let options = RemoteDatabaseOptions::parse_from_map(&self.request.options)?;
+        let mut merged_options = self.request.options.clone();
+        Self::apply_env_defaults(&ENV_VARS_TO_STORAGE_OPTS, &mut merged_options);
+        let options = RemoteDatabaseOptions::parse_from_map(&merged_options)?;
         let region = options.region.ok_or_else(|| Error::InvalidInput {
             message: "A region is required when connecting to LanceDb Cloud".to_string(),
@@ -1277,8 +1349,6 @@ mod test_utils {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::fs::create_dir_all;
use crate::database::listing::{ListingDatabaseOptions, NewTableConfig}; use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
use crate::query::QueryBase; use crate::query::QueryBase;
use crate::query::{ExecutableQuery, QueryExecutionOptions}; use crate::query::{ExecutableQuery, QueryExecutionOptions};
@@ -1302,6 +1372,23 @@ mod tests {
assert_eq!(tc.connection.uri(), tc.uri); assert_eq!(tc.connection.uri(), tc.uri);
} }
#[cfg(feature = "remote")]
#[test]
fn test_apply_env_defaults() {
let env_key = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_KEY";
let env_val = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_VAL";
let opts_key = "test_apply_env_defaults_environment_variable_opts_key";
std::env::set_var(env_key, env_val);
let mut options = HashMap::new();
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
assert_eq!(Some(&env_val.to_string()), options.get(opts_key));
options.insert(opts_key.to_string(), "EXPLICIT-VALUE".to_string());
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
}
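The precedence rule exercised by `test_apply_env_defaults` above — environment variables fill gaps but never override explicit options — can be sketched in isolation. This stand-in injects the environment as a map rather than reading `std::env::var`, so the behavior is deterministic; the function and key names are illustrative, not the crate's API:

```rust
use std::collections::HashMap;

// Copy an env-var value into `options` only when the caller has not already
// set the corresponding key -- explicit options always win.
fn apply_env_defaults(
    mapping: &[(&str, &str)],
    env: &HashMap<String, String>, // injected in place of std::env::var for determinism
    options: &mut HashMap<String, String>,
) {
    for (env_key, opt_key) in mapping {
        if let Some(value) = env.get(*env_key) {
            options
                .entry((*opt_key).to_string())
                .or_insert_with(|| value.clone());
        }
    }
}

fn main() {
    let env = HashMap::from([(
        "AZURE_STORAGE_ACCOUNT_NAME".to_string(),
        "from-env".to_string(),
    )]);
    let mapping = [("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];

    // Empty options: the env value is used as the default.
    let mut opts = HashMap::new();
    apply_env_defaults(&mapping, &env, &mut opts);
    assert_eq!(opts["azure_storage_account_name"], "from-env");

    // Explicit value: left untouched by a second pass.
    opts.insert("azure_storage_account_name".to_string(), "explicit".to_string());
    apply_env_defaults(&mapping, &env, &mut opts);
    assert_eq!(opts["azure_storage_account_name"], "explicit");
}
```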
     #[cfg(not(windows))]
     #[tokio::test]
     async fn test_connect_relative() {
@@ -1526,18 +1613,27 @@
     #[tokio::test]
     async fn drop_table() {
-        let tmp_dir = tempdir().unwrap();
-        let uri = tmp_dir.path().to_str().unwrap();
-        let db = connect(uri).execute().await.unwrap();
+        let tc = new_test_connection().await.unwrap();
+        let db = tc.connection;
+        if tc.is_remote {
+            // All the typical endpoints such as s3:///, file-object-store:///, etc. treat drop_table
+            // as idempotent.
+            assert!(db.drop_table("invalid_table", &[]).await.is_ok());
+        } else {
+            // The behavior of drop_table when using a file:/// endpoint differs from all other
+            // object providers, in that it returns an error when deleting a non-existent table.
+            assert!(matches!(
+                db.drop_table("invalid_table", &[]).await,
+                Err(crate::Error::TableNotFound { .. }),
+            ));
+        }
-        // drop non-exist table
-        assert!(matches!(
-            db.drop_table("invalid_table", &[]).await,
-            Err(crate::Error::TableNotFound { .. }),
-        ));
-        create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
+        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
+        db.create_empty_table("table1", schema.clone())
+            .execute()
+            .await
+            .unwrap();
         db.drop_table("table1", &[]).await.unwrap();
         let tables = db.table_names().execute().await.unwrap();
@@ -1624,4 +1720,128 @@
         let cloned_count = cloned_table.count_rows(None).await.unwrap();
         assert_eq!(source_count, cloned_count);
     }
+
+    #[tokio::test]
+    async fn test_create_empty_table_with_embeddings() {
+        use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction};
+        use arrow_array::{
+            Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
+        };
+        use std::borrow::Cow;
+
+        #[derive(Debug, Clone)]
+        struct MockEmbedding {
+            dim: usize,
+        }
+
+        impl EmbeddingFunction for MockEmbedding {
+            fn name(&self) -> &str {
+                "test_embedding"
+            }
+            fn source_type(&self) -> Result<Cow<'_, DataType>> {
+                Ok(Cow::Owned(DataType::Utf8))
+            }
+            fn dest_type(&self) -> Result<Cow<'_, DataType>> {
+                Ok(Cow::Owned(DataType::new_fixed_size_list(
+                    DataType::Float32,
+                    self.dim as i32,
+                    true,
+                )))
+            }
+            fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
+                let len = source.len();
+                let values = vec![1.0f32; len * self.dim];
+                let values = Arc::new(Float32Array::from(values));
+                let field = Arc::new(Field::new("item", DataType::Float32, true));
+                Ok(Arc::new(FixedSizeListArray::new(
+                    field,
+                    self.dim as i32,
+                    values,
+                    None,
+                )))
+            }
+            fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
+                unimplemented!()
+            }
+        }
+
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+        let db = connect(uri).execute().await.unwrap();
+
+        let embed_func = Arc::new(MockEmbedding { dim: 128 });
+        db.embedding_registry()
+            .register("test_embedding", embed_func.clone())
+            .unwrap();
+
+        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
+        let ed = EmbeddingDefinition {
+            source_column: "name".to_owned(),
+            dest_column: Some("name_embedding".to_owned()),
+            embedding_name: "test_embedding".to_owned(),
+        };
+
+        let table = db
+            .create_empty_table("test", schema)
+            .mode(CreateTableMode::Overwrite)
+            .add_embedding(ed)
+            .unwrap()
+            .execute()
+            .await
+            .unwrap();
+
+        let table_schema = table.schema().await.unwrap();
+        assert!(table_schema.column_with_name("name").is_some());
+        assert!(table_schema.column_with_name("name_embedding").is_some());
+
+        let embedding_field = table_schema.field_with_name("name_embedding").unwrap();
+        assert_eq!(
+            embedding_field.data_type(),
+            &DataType::new_fixed_size_list(DataType::Float32, 128, true)
+        );
+
+        let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
+        let input_batch = RecordBatch::try_new(
+            input_schema.clone(),
+            vec![Arc::new(StringArray::from(vec![
+                Some("Alice"),
+                Some("Bob"),
+                Some("Charlie"),
+            ]))],
+        )
+        .unwrap();
+        let input_reader = Box::new(RecordBatchIterator::new(
+            vec![Ok(input_batch)].into_iter(),
+            input_schema,
+        ));
+        table.add(input_reader).execute().await.unwrap();
+
+        let results = table
+            .query()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        assert_eq!(results.len(), 1);
+        let batch = &results[0];
+        assert_eq!(batch.num_rows(), 3);
+        assert!(batch.column_by_name("name_embedding").is_some());
+
+        let embedding_col = batch
+            .column_by_name("name_embedding")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<FixedSizeListArray>()
+            .unwrap();
+        assert_eq!(embedding_col.len(), 3);
+    }
 }

View File

@@ -12,7 +12,7 @@ use lance::dataset::{builder::DatasetBuilder, ReadParams, WriteMode};
 use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
 use lance_datafusion::utils::StreamingWriteSource;
 use lance_encoding::version::LanceFileVersion;
-use lance_io::object_store::StorageOptionsProvider;
+use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider};
 use lance_table::io::commit::commit_handler_from_url;
 use object_store::local::LocalFileSystem;
 use snafu::ResultExt;
@@ -356,7 +356,13 @@ impl ListingDatabase {
             .clone()
             .unwrap_or_else(|| Arc::new(lance::session::Session::default()));
         let os_params = ObjectStoreParams {
-            storage_options: Some(options.storage_options.clone()),
+            storage_options_accessor: if options.storage_options.is_empty() {
+                None
+            } else {
+                Some(Arc::new(StorageOptionsAccessor::with_static_options(
+                    options.storage_options.clone(),
+                )))
+            },
             ..Default::default()
         };
         let (object_store, base_path) = ObjectStore::from_uri_and_params(
@@ -492,7 +498,13 @@ impl ListingDatabase {
     async fn drop_tables(&self, names: Vec<String>) -> Result<()> {
         let object_store_params = ObjectStoreParams {
-            storage_options: Some(self.storage_options.clone()),
+            storage_options_accessor: if self.storage_options.is_empty() {
+                None
+            } else {
+                Some(Arc::new(StorageOptionsAccessor::with_static_options(
+                    self.storage_options.clone(),
+                )))
+            },
             ..Default::default()
         };
         let mut uri = self.uri.clone();
@@ -541,7 +553,7 @@ impl ListingDatabase {
             .lance_write_params
             .as_ref()
             .and_then(|p| p.store_params.as_ref())
-            .and_then(|sp| sp.storage_options.as_ref());
+            .and_then(|sp| sp.storage_options());
         let storage_version_override = storage_options
             .and_then(|opts| opts.get(OPT_NEW_TABLE_STORAGE_VERSION))
@@ -592,21 +604,20 @@ impl ListingDatabase {
         // will cause a new connection to be created, and that connection will
         // be dropped from the cache when python GCs the table object, which
         // confounds reuse across tables.
-        if !self.storage_options.is_empty() {
-            let storage_options = write_params
-                .store_params
-                .get_or_insert_with(Default::default)
-                .storage_options
-                .get_or_insert_with(Default::default);
-            self.inherit_storage_options(storage_options);
-        }
-
-        // Set storage options provider if available
-        if self.storage_options_provider.is_some() {
-            write_params
-                .store_params
-                .get_or_insert_with(Default::default)
-                .storage_options_provider = self.storage_options_provider.clone();
+        if !self.storage_options.is_empty() || self.storage_options_provider.is_some() {
+            let store_params = write_params
+                .store_params
+                .get_or_insert_with(Default::default);
+            let mut storage_options = store_params.storage_options().cloned().unwrap_or_default();
+            if !self.storage_options.is_empty() {
+                self.inherit_storage_options(&mut storage_options);
+            }
+            let accessor = if let Some(ref provider) = self.storage_options_provider {
+                StorageOptionsAccessor::with_initial_and_provider(storage_options, provider.clone())
+            } else {
+                StorageOptionsAccessor::with_static_options(storage_options)
+            };
+            store_params.storage_options_accessor = Some(Arc::new(accessor));
         }
         write_params.data_storage_version = self
@@ -892,7 +903,13 @@ impl Database for ListingDatabase {
         validate_table_name(&request.target_table_name)?;
         let storage_params = ObjectStoreParams {
-            storage_options: Some(self.storage_options.clone()),
+            storage_options_accessor: if self.storage_options.is_empty() {
+                None
+            } else {
+                Some(Arc::new(StorageOptionsAccessor::with_static_options(
+                    self.storage_options.clone(),
+                )))
+            },
             ..Default::default()
         };
         let read_params = ReadParams {
@@ -956,25 +973,28 @@ impl Database for ListingDatabase {
         // will cause a new connection to be created, and that connection will
         // be dropped from the cache when python GCs the table object, which
         // confounds reuse across tables.
-        if !self.storage_options.is_empty() {
-            let storage_options = request
-                .lance_read_params
-                .get_or_insert_with(Default::default)
-                .store_options
-                .get_or_insert_with(Default::default)
-                .storage_options
-                .get_or_insert_with(Default::default);
-            self.inherit_storage_options(storage_options);
-        }
-
-        // Set storage options provider if available
-        if self.storage_options_provider.is_some() {
-            request
-                .lance_read_params
-                .get_or_insert_with(Default::default)
-                .store_options
-                .get_or_insert_with(Default::default)
-                .storage_options_provider = self.storage_options_provider.clone();
+        if !self.storage_options.is_empty() || self.storage_options_provider.is_some() {
+            let store_params = request
+                .lance_read_params
+                .get_or_insert_with(Default::default)
+                .store_options
+                .get_or_insert_with(Default::default);
+            let mut storage_options = store_params.storage_options().cloned().unwrap_or_default();
+            if !self.storage_options.is_empty() {
+                self.inherit_storage_options(&mut storage_options);
+            }
+            // Preserve request-level provider if no connection-level provider exists
+            let request_provider = store_params
+                .storage_options_accessor
+                .as_ref()
+                .and_then(|a| a.provider().cloned());
+            let provider = self.storage_options_provider.clone().or(request_provider);
+            let accessor = if let Some(provider) = provider {
+                StorageOptionsAccessor::with_initial_and_provider(storage_options, provider)
+            } else {
+                StorageOptionsAccessor::with_static_options(storage_options)
+            };
+            store_params.storage_options_accessor = Some(Arc::new(accessor));
         }
         // Some ReadParams are exposed in the OpenTableBuilder, but we also
@@ -1881,7 +1901,9 @@ mod tests {
         let write_options = WriteOptions {
             lance_write_params: Some(lance::dataset::WriteParams {
                 store_params: Some(lance::io::ObjectStoreParams {
-                    storage_options: Some(storage_options),
+                    storage_options_accessor: Some(Arc::new(
+                        StorageOptionsAccessor::with_static_options(storage_options),
+                    )),
                     ..Default::default()
                 }),
                 ..Default::default()
@@ -1955,7 +1977,9 @@ mod tests {
         let write_options = WriteOptions {
             lance_write_params: Some(lance::dataset::WriteParams {
                 store_params: Some(lance::io::ObjectStoreParams {
-                    storage_options: Some(storage_options),
+                    storage_options_accessor: Some(Arc::new(
+                        StorageOptionsAccessor::with_static_options(storage_options),
+                    )),
                     ..Default::default()
                 }),
                 ..Default::default()

View File

@@ -9,14 +9,15 @@ use std::sync::Arc;
 use async_trait::async_trait;
 use lance_namespace::{
     models::{
-        CreateNamespaceRequest, CreateNamespaceResponse, DeclareTableRequest,
-        DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest,
-        DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest,
-        ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
+        CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
+        DeclareTableRequest, DescribeNamespaceRequest, DescribeNamespaceResponse,
+        DescribeTableRequest, DropNamespaceRequest, DropNamespaceResponse, DropTableRequest,
+        ListNamespacesRequest, ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
     },
     LanceNamespace,
 };
 use lance_namespace_impls::ConnectBuilder;
+use log::warn;
 use crate::database::ReadConsistency;
 use crate::error::{Error, Result};
@@ -154,7 +155,6 @@ impl Database for LanceNamespaceDatabase {
         table_id.push(request.name.clone());
         let describe_request = DescribeTableRequest {
             id: Some(table_id.clone()),
-            version: None,
             ..Default::default()
         };
@@ -205,26 +205,53 @@
         let mut table_id = request.namespace.clone();
         table_id.push(request.name.clone());
-        let create_empty_request = DeclareTableRequest {
+        // Try declare_table first, falling back to create_empty_table for backwards
+        // compatibility with older namespace clients that don't support declare_table
+        let declare_request = DeclareTableRequest {
             id: Some(table_id.clone()),
-            location: None,
-            vend_credentials: None,
             ..Default::default()
         };
-        let create_empty_response = self
-            .namespace
-            .declare_table(create_empty_request)
-            .await
-            .map_err(|e| Error::Runtime {
-                message: format!("Failed to declare table: {}", e),
-            })?;
-        let location = create_empty_response
-            .location
-            .ok_or_else(|| Error::Runtime {
-                message: "Table location is missing from create_empty_table response".to_string(),
-            })?;
+        let location = match self.namespace.declare_table(declare_request).await {
+            Ok(response) => response.location.ok_or_else(|| Error::Runtime {
+                message: "Table location is missing from declare_table response".to_string(),
+            })?,
+            Err(e) => {
+                // Check if the error is "not supported" and try create_empty_table as fallback
+                let err_str = e.to_string().to_lowercase();
+                if err_str.contains("not supported") || err_str.contains("not implemented") {
+                    warn!(
+                        "declare_table is not supported by the namespace client, \
+                        falling back to deprecated create_empty_table. \
+                        create_empty_table is deprecated and will be removed in Lance 3.0.0. \
+                        Please upgrade your namespace client to support declare_table."
+                    );
+                    #[allow(deprecated)]
+                    let create_empty_request = CreateEmptyTableRequest {
+                        id: Some(table_id.clone()),
+                        ..Default::default()
+                    };
+                    #[allow(deprecated)]
+                    let create_response = self
+                        .namespace
+                        .create_empty_table(create_empty_request)
+                        .await
+                        .map_err(|e| Error::Runtime {
+                            message: format!("Failed to create empty table: {}", e),
+                        })?;
+                    create_response.location.ok_or_else(|| Error::Runtime {
+                        message: "Table location is missing from create_empty_table response"
+                            .to_string(),
+                    })?
+                } else {
+                    return Err(Error::Runtime {
+                        message: format!("Failed to declare table: {}", e),
+                    });
+                }
+            }
+        };
let native_table = NativeTable::create_from_namespace( let native_table = NativeTable::create_from_namespace(
self.namespace.clone(), self.namespace.clone(),
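The try-new-then-fall-back shape of the hunk above can be sketched in isolation. Everything below (`call_new`, `call_deprecated`, `resolve_location`, the location string) is a hypothetical stand-in, not a LanceDB API; only the error-substring check mirrors the diff.

```rust
// Hypothetical sketch: try the new API first and, when the server reports it
// as unsupported, retry the deprecated one. All names here are illustrative.
fn call_new() -> Result<String, String> {
    Err("operation not supported by this server".to_string())
}

fn call_deprecated() -> Result<String, String> {
    Ok("table://fallback/location".to_string())
}

fn resolve_location() -> Result<String, String> {
    match call_new() {
        Ok(loc) => Ok(loc),
        Err(e) => {
            // Only fall back on "not supported"/"not implemented";
            // any other error propagates unchanged.
            let msg = e.to_lowercase();
            if msg.contains("not supported") || msg.contains("not implemented") {
                call_deprecated()
            } else {
                Err(e)
            }
        }
    }
}

fn main() {
    assert_eq!(resolve_location().unwrap(), "table://fallback/location");
    println!("ok");
}
```

Matching on an error's display string is fragile compared to a typed error variant, which is why the real change also logs a deprecation warning steering clients toward `declare_table`.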
@@ -439,8 +466,6 @@ mod tests {
         // Create a child namespace first
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(vec!["test_ns".into()]),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -501,8 +526,6 @@ mod tests {
         // Create a child namespace first
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(vec!["test_ns".into()]),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -566,8 +589,6 @@ mod tests {
         // Create a child namespace first
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(vec!["test_ns".into()]),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -651,8 +672,6 @@ mod tests {
         // Create a child namespace first
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(vec!["test_ns".into()]),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -708,8 +727,6 @@ mod tests {
         // Create a child namespace first
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(vec!["test_ns".into()]),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -790,8 +807,6 @@ mod tests {
         // Create a child namespace first
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(vec!["test_ns".into()]),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -825,8 +840,6 @@ mod tests {
         // Create a child namespace first
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(vec!["test_ns".into()]),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await


@@ -19,7 +19,7 @@ use crate::{
         split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
         util::{rename_column, TemporaryDirectory},
     },
-    query::{ExecutableQuery, QueryBase},
+    query::{ExecutableQuery, QueryBase, Select},
     Error, Result, Table,
 };
@@ -27,6 +27,8 @@ pub const SRC_ROW_ID_COL: &str = "row_id";
 pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
 
+pub const DEFAULT_MEMORY_LIMIT: usize = 100 * 1024 * 1024;
+
 /// Where to store the permutation table
 #[derive(Debug, Clone, Default)]
 enum PermutationDestination {
@@ -167,10 +169,20 @@ impl PermutationBuilder {
         &self,
         data: SendableRecordBatchStream,
     ) -> Result<SendableRecordBatchStream> {
+        let memory_limit = std::env::var("LANCEDB_PERM_BUILDER_MEMORY_LIMIT")
+            .unwrap_or_else(|_| DEFAULT_MEMORY_LIMIT.to_string())
+            .parse::<usize>()
+            .unwrap_or_else(|_| {
+                log::error!(
+                    "Failed to parse LANCEDB_PERM_BUILDER_MEMORY_LIMIT, using default: {}",
+                    DEFAULT_MEMORY_LIMIT
+                );
+                DEFAULT_MEMORY_LIMIT
+            });
         let ctx = SessionContext::new_with_config_rt(
             SessionConfig::default(),
             RuntimeEnvBuilder::new()
-                .with_memory_limit(100 * 1024 * 1024, 1.0)
+                .with_memory_limit(memory_limit, 1.0)
                 .with_disk_manager_builder(
                     DiskManagerBuilder::default()
                         .with_mode(self.config.temp_dir.to_disk_manager_mode()),
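The env-var override added above follows a common pattern: take the raw value, parse it, and fall back to a compile-time default on any failure. A minimal sketch, with `parse_limit` as an assumed helper name and the input passed as a plain `Option<&str>` so it runs without touching the process environment:

```rust
const DEFAULT_LIMIT: usize = 100 * 1024 * 1024; // 100 MiB

/// Parse a byte limit, falling back to the default when the value is
/// missing or not a valid usize (mirroring the env-var override pattern).
fn parse_limit(raw: Option<&str>) -> usize {
    raw.and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(DEFAULT_LIMIT)
}

fn main() {
    assert_eq!(parse_limit(None), DEFAULT_LIMIT); // unset -> default
    assert_eq!(parse_limit(Some("1048576")), 1_048_576); // valid override
    assert_eq!(parse_limit(Some("lots")), DEFAULT_LIMIT); // unparsable -> default
    println!("ok");
}
```

The diff's version deliberately logs an error on a present-but-unparsable value instead of failing silently, so a typo in the variable is visible rather than quietly reverting to 100 MiB.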
@@ -232,7 +244,7 @@ impl PermutationBuilder {
     /// Builds the permutation table and stores it in the given database.
     pub async fn build(self) -> Result<Table> {
         // First pass, apply filter and load row ids
-        let mut rows = self.base_table.query().with_row_id();
+        let mut rows = self.base_table.query().select(Select::columns(&[ROW_ID]));
 
         if let Some(filter) = &self.config.filter {
             rows = rows.only_if(filter);
@@ -321,6 +333,47 @@ mod tests {
 
     use super::*;
 
+    #[tokio::test]
+    async fn test_permutation_table_only_stores_row_id_and_split_id() {
+        let temp_dir = tempfile::tempdir().unwrap();
+        let db = connect(temp_dir.path().to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let initial_data = lance_datagen::gen_batch()
+            .col("col_a", lance_datagen::array::step::<Int32Type>())
+            .col("col_b", lance_datagen::array::step::<Int32Type>())
+            .into_ldb_stream(RowCount::from(100), BatchCount::from(10));
+        let data_table = db
+            .create_table_streaming("base_tbl", initial_data)
+            .execute()
+            .await
+            .unwrap();
+
+        let permutation_table = PermutationBuilder::new(data_table.clone())
+            .with_split_strategy(
+                SplitStrategy::Sequential {
+                    sizes: SplitSizes::Percentages(vec![0.5, 0.5]),
+                },
+                None,
+            )
+            .with_filter("col_a > 57".to_string())
+            .build()
+            .await
+            .unwrap();
+
+        let schema = permutation_table.schema().await.unwrap();
+        let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
+        assert_eq!(
+            field_names,
+            vec!["row_id", "split_id"],
+            "Permutation table should only contain row_id and split_id columns, but found: {:?}",
+            field_names,
+        );
+    }
+
     #[tokio::test]
     async fn test_permutation_builder() {
         let temp_dir = tempfile::tempdir().unwrap();
@@ -352,8 +405,6 @@ mod tests {
         .await
         .unwrap();
 
-        println!("permutation_table: {:?}", permutation_table);
-
         // Potentially brittle seed-dependent values below
         assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
         assert_eq!(


@@ -171,7 +171,7 @@ impl Shuffler {
         // This is kind of an annoying limitation but if we allow runt clumps from batches then
         // clumps will get unaligned and we will mess up the clumps when we do the in-memory
         // shuffle step. If this is a problem we can probably figure out a better way to do this.
-        if !is_last && batch.num_rows() as u64 % clump_size != 0 {
+        if !is_last && !(batch.num_rows() as u64).is_multiple_of(clump_size) {
            return Err(Error::Runtime {
                 message: format!(
                     "Expected batch size ({}) to be divisible by clump size ({})",
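Several hunks in this compare swap manual `%` checks for `is_multiple_of`, which stabilized for unsigned integers around Rust 1.87. A small sketch of the equivalence, and of the one behavioral difference (a zero divisor):

```rust
fn main() {
    let clump_size: u64 = 4;
    for (rows, expected) in [(100u64, true), (99, false)] {
        // For a non-zero divisor, is_multiple_of matches the old `%` check.
        assert_eq!(rows.is_multiple_of(clump_size), expected);
        assert_eq!(rows % clump_size == 0, expected);
    }
    // Unlike `%`, a zero divisor does not panic: n.is_multiple_of(0) is n == 0.
    assert!(!7u64.is_multiple_of(0));
    assert!(0u64.is_multiple_of(0));
    println!("ok");
}
```

Besides reading better, this removes a latent divide-by-zero panic path wherever the divisor is runtime data, as `clump_size` is here.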


@@ -1,12 +1,9 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
-use std::{
-    iter,
-    sync::{
-        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
-        Arc,
-    },
+use std::sync::{
+    atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
+    Arc,
 };
 
 use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
@@ -15,6 +12,8 @@ use datafusion_common::hash_utils::create_hashes;
 use futures::{StreamExt, TryStreamExt};
 use lance_arrow::SchemaExt;
+use lance_core::ROW_ID;
+
 use crate::{
     arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
     dataloader::{
@@ -158,7 +157,7 @@ impl Splitter {
             remaining_in_split
         };
-        split_ids.extend(iter::repeat(split_id as u64).take(rows_to_add as usize));
+        split_ids.extend(std::iter::repeat_n(split_id as u64, rows_to_add as usize));
         if done {
             // Quit early if we've run out of splits
             break;
@@ -363,11 +362,15 @@ impl Splitter {
     pub fn project(&self, query: Query) -> Query {
         match &self.strategy {
-            SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
-                SPLIT_ID_COLUMN.to_string(),
-                calculation.clone(),
-            )])),
-            SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
+            SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![
+                (SPLIT_ID_COLUMN.to_string(), calculation.clone()),
+                (ROW_ID.to_string(), ROW_ID.to_string()),
+            ])),
+            SplitStrategy::Hash { columns, .. } => {
+                let mut cols = columns.clone();
+                cols.push(ROW_ID.to_string());
+                query.select(Select::Columns(cols))
+            }
             _ => query,
         }
     }
@@ -662,7 +665,7 @@ mod tests {
         assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
         let mut expected = Vec::with_capacity(total_split_sizes as usize);
         for (i, size) in expected_split_sizes.iter().enumerate() {
-            expected.extend(iter::repeat(i as u64).take(*size as usize));
+            expected.extend(std::iter::repeat_n(i as u64, *size as usize));
         }
         let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;
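The `repeat`/`take` to `repeat_n` migration seen in this file (and again below in table.rs) is a drop-in rename; `std::iter::repeat_n` stabilized in Rust 1.82. A minimal sketch:

```rust
use std::iter;

fn main() {
    // iter::repeat_n(value, n) replaces iter::repeat(value).take(n).
    let ids: Vec<u64> = iter::repeat_n(7u64, 3).collect();
    assert_eq!(ids, vec![7, 7, 7]);

    // A zero count yields an empty iterator, same as .take(0).
    assert_eq!(iter::repeat_n("x", 0).count(), 0);
    println!("ok");
}
```

Beyond brevity, `repeat_n` reports an exact length to the consumer (it implements `ExactSizeIterator`), which `repeat(..).take(n)` cannot, so collectors like `Vec::extend` can pre-allocate.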


@@ -297,10 +297,10 @@ impl IvfPqIndexBuilder {
 }
 
 pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
-    if dim % 16 == 0 {
+    if dim.is_multiple_of(16) {
         // Should be more aggressive than this default.
         dim / 16
-    } else if dim % 8 == 0 {
+    } else if dim.is_multiple_of(8) {
         dim / 8
     } else {
         log::warn!(
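The heuristic above can be restated standalone. This is a sketch, not the library function: the real code logs a warning in the final branch, and the `1` returned there is an assumed placeholder for whatever default the warning path produces.

```rust
// Prefer 16-wide sub-vectors, fall back to 8-wide, otherwise punt.
fn suggested_num_sub_vectors(dim: u32) -> u32 {
    if dim.is_multiple_of(16) {
        dim / 16
    } else if dim.is_multiple_of(8) {
        dim / 8
    } else {
        // The real function warns here; 1 is an assumed placeholder.
        1
    }
}

fn main() {
    assert_eq!(suggested_num_sub_vectors(128), 8); // 128 = 16 * 8
    assert_eq!(suggested_num_sub_vectors(24), 3); // not /16, but 24 = 8 * 3
    assert_eq!(suggested_num_sub_vectors(10), 1); // fallback branch
    println!("ok");
}
```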


@@ -51,24 +51,19 @@
 //! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
 //! - `db://dbname` - Lance Cloud
 //!
-//! You can also use [`ConnectOptions`] to configure the connection to the database.
+//! You can also use [`ConnectBuilder`] to configure the connection to the database.
 //!
 //! ```rust
-//! # #[cfg(feature = "aws")]
-//! # {
-//! use object_store::aws::AwsCredential;
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
 //! let db = lancedb::connect("data/sample-lancedb")
-//!     .aws_creds(AwsCredential {
-//!         key_id: "some_key".to_string(),
-//!         secret_key: "some_secret".to_string(),
-//!         token: None,
-//!     })
+//!     .storage_options([
+//!         ("aws_access_key_id", "some_key"),
+//!         ("aws_secret_access_key", "some_secret"),
+//!     ])
 //!     .execute()
 //!     .await
 //!     .unwrap();
 //! # });
-//! # }
 //! ```
 //!
 //! LanceDB uses [arrow-rs](https://github.com/apache/arrow-rs) to define schema, data types and array itself.


@@ -1718,8 +1718,6 @@ mod tests {
         let namespace = vec!["test_ns".to_string()];
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(namespace.clone()),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -1745,8 +1743,6 @@ mod tests {
         let list_response = conn
             .list_tables(ListTablesRequest {
                 id: Some(namespace.clone()),
-                page_token: None,
-                limit: None,
                 ..Default::default()
             })
             .await
@@ -1758,8 +1754,6 @@ mod tests {
         let list_response = namespace_client
             .list_tables(ListTablesRequest {
                 id: Some(namespace.clone()),
-                page_token: None,
-                limit: None,
                 ..Default::default()
             })
             .await
@@ -1800,8 +1794,6 @@ mod tests {
         let namespace = vec!["multi_table_ns".to_string()];
         conn.create_namespace(CreateNamespaceRequest {
             id: Some(namespace.clone()),
-            mode: None,
-            properties: None,
             ..Default::default()
         })
         .await
@@ -1827,8 +1819,6 @@ mod tests {
         let list_response = conn
             .list_tables(ListTablesRequest {
                 id: Some(namespace.clone()),
-                page_token: None,
-                limit: None,
                 ..Default::default()
             })
             .await


@@ -40,7 +40,7 @@ use lance_index::vector::pq::PQBuildParams;
 use lance_index::vector::sq::builder::SQBuildParams;
 use lance_index::DatasetIndexExt;
 use lance_index::IndexType;
-use lance_io::object_store::LanceNamespaceStorageOptionsProvider;
+use lance_io::object_store::{LanceNamespaceStorageOptionsProvider, StorageOptionsAccessor};
 use lance_namespace::models::{
     QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns,
     QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery,
@@ -79,10 +79,11 @@ use self::merge::MergeInsertBuilder;
 pub mod datafusion;
 pub(crate) mod dataset;
+pub mod delete;
 pub mod merge;
 
 use crate::index::waiter::wait_for_index;
 pub use chrono::Duration;
+pub use delete::DeleteResult;
 use futures::future::{join_all, Either};
 pub use lance::dataset::optimize::CompactionOptions;
 pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
@@ -446,15 +447,6 @@ pub struct AddResult {
     pub version: u64,
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
-pub struct DeleteResult {
-    // The commit version associated with the operation.
-    // A version of `0` indicates compatibility with legacy servers that do not return
-    /// a commit version.
-    #[serde(default)]
-    pub version: u64,
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
 pub struct MergeResult {
     // The commit version associated with the operation.
@@ -1425,9 +1417,7 @@ impl Table {
             })
             .collect::<Vec<_>>();
 
-        let unioned = UnionExec::try_new(projected_plans).map_err(|e| Error::Runtime {
-            message: format!("Failed to build union plan: {e}"),
-        })?;
+        let unioned = Arc::new(UnionExec::new(projected_plans));
 
         // We require 1 partition in the final output
         let repartitioned = RepartitionExec::try_new(
             unioned,
@@ -1668,18 +1658,14 @@ impl NativeTable {
         // Use DatasetBuilder::from_namespace which automatically fetches location
         // and storage options from the namespace
-        let builder = DatasetBuilder::from_namespace(
-            namespace_client.clone(),
-            table_id,
-            false, // Don't ignore namespace storage options
-        )
-        .await
-        .map_err(|e| match e {
-            lance::Error::Namespace { source, .. } => Error::Runtime {
-                message: format!("Failed to get table info from namespace: {:?}", source),
-            },
-            source => Error::Lance { source },
-        })?;
+        let builder = DatasetBuilder::from_namespace(namespace_client.clone(), table_id)
+            .await
+            .map_err(|e| match e {
+                lance::Error::Namespace { source, .. } => Error::Runtime {
+                    message: format!("Failed to get table info from namespace: {:?}", source),
+                },
+                source => Error::Lance { source },
+            })?;
 
         let dataset = builder
             .with_read_params(params)
@@ -1883,7 +1869,13 @@ impl NativeTable {
         let store_params = params
             .store_params
             .get_or_insert_with(ObjectStoreParams::default);
-        store_params.storage_options_provider = Some(storage_options_provider);
+        let accessor = match store_params.storage_options().cloned() {
+            Some(options) => {
+                StorageOptionsAccessor::with_initial_and_provider(options, storage_options_provider)
+            }
+            None => StorageOptionsAccessor::with_provider(storage_options_provider),
+        };
+        store_params.storage_options_accessor = Some(Arc::new(accessor));
 
         // Patch the params if we have a write store wrapper
         let params = match write_store_wrapper.clone() {
@@ -2059,7 +2051,7 @@ impl NativeTable {
             return provided;
         }
         let suggested = suggested_num_sub_vectors(dim);
-        if num_bits.is_some_and(|num_bits| num_bits == 4) && suggested % 2 != 0 {
+        if num_bits.is_some_and(|num_bits| num_bits == 4) && !suggested.is_multiple_of(2) {
             // num_sub_vectors must be even when 4 bits are used
             suggested + 1
         } else {
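The even-count adjustment in this hunk composes `Option::is_some_and` with a parity check: only when 4-bit PQ is requested does an odd suggestion get bumped to the next even number. A self-contained sketch (`ensure_even_for_4bit` is an assumed name, and plain `%` stands in for the parity test):

```rust
// Bump an odd sub-vector count to even, but only for 4-bit PQ.
fn ensure_even_for_4bit(suggested: u32, num_bits: Option<u32>) -> u32 {
    if num_bits.is_some_and(|b| b == 4) && suggested % 2 != 0 {
        suggested + 1
    } else {
        suggested
    }
}

fn main() {
    assert_eq!(ensure_even_for_4bit(25, Some(4)), 26); // odd + 4-bit -> bump
    assert_eq!(ensure_even_for_4bit(25, Some(8)), 25); // 8-bit -> unchanged
    assert_eq!(ensure_even_for_4bit(25, None), 25); // bits unset -> unchanged
    println!("ok");
}
```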
@@ -2349,7 +2341,7 @@ impl NativeTable {
         };
 
         // Convert select to columns list
-        let columns: Option<Box<QueryTableRequestColumns>> = match &vq.base.select {
+        let columns = match &vq.base.select {
             Select::All => None,
             Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
                 column_names: Some(cols.clone()),
@@ -2407,7 +2399,6 @@ impl NativeTable {
             with_row_id: Some(vq.base.with_row_id),
             bypass_vector_index: Some(!vq.use_index),
             full_text_query,
-            version: None,
             ..Default::default()
         })
     }
@@ -2426,7 +2417,7 @@ impl NativeTable {
             .map(|f| self.filter_to_sql(f))
             .transpose()?;
 
-        let columns: Option<Box<QueryTableRequestColumns>> = match &q.select {
+        let columns = match &q.select {
             Select::All => None,
             Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
                 column_names: Some(cols.clone()),
@@ -2470,18 +2461,10 @@ impl NativeTable {
             columns,
             prefilter: Some(q.prefilter),
             offset: q.offset.map(|o| o as i32),
-            ef: None,
-            refine_factor: None,
-            distance_type: None,
-            nprobes: None,
             vector_column: None, // No vector column for plain queries
             with_row_id: Some(q.with_row_id),
             bypass_vector_index: Some(true), // No vector index for plain queries
             full_text_query,
-            version: None,
-            fast_search: None,
-            lower_bound: None,
-            upper_bound: None,
             ..Default::default()
         })
     }
@@ -3087,11 +3070,8 @@ impl BaseTable for NativeTable {
 
     /// Delete rows from the table
     async fn delete(&self, predicate: &str) -> Result<DeleteResult> {
-        let mut dataset = self.dataset.get_mut().await?;
-        dataset.delete(predicate).await?;
-        Ok(DeleteResult {
-            version: dataset.version().version,
-        })
+        // Delegate to the submodule implementation
+        delete::execute_delete(self, predicate).await
     }
 
     async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
@@ -3244,7 +3224,7 @@ impl BaseTable for NativeTable {
             .get()
             .await
             .ok()
-            .and_then(|dataset| dataset.storage_options().cloned())
+            .and_then(|dataset| dataset.initial_storage_options().cloned())
     }
 
     async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> { async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
@@ -3409,7 +3389,6 @@ pub struct FragmentSummaryStats {
 #[cfg(test)]
 #[allow(deprecated)]
 mod tests {
-    use std::iter;
     use std::sync::atomic::{AtomicBool, Ordering};
     use std::sync::Arc;
     use std::time::Duration;
@@ -4026,7 +4005,7 @@ mod tests {
             schema.clone(),
             vec![
                 Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
-                Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
+                Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))),
             ],
         )],
         schema,
@@ -5154,15 +5133,16 @@ mod tests {
         let any_query = AnyQuery::VectorQuery(vq);
         let ns_request = table.convert_to_namespace_query(&any_query).unwrap();
 
-        let column_names = ns_request
-            .columns
-            .as_ref()
-            .and_then(|cols| cols.column_names.clone());
-
         assert_eq!(ns_request.k, 10);
         assert_eq!(ns_request.offset, Some(5));
         assert_eq!(ns_request.filter, Some("id > 0".to_string()));
-        assert_eq!(column_names, Some(vec!["id".to_string()]));
+        assert_eq!(
+            ns_request
+                .columns
+                .as_ref()
+                .and_then(|c| c.column_names.as_ref()),
+            Some(&vec!["id".to_string()])
+        );
         assert_eq!(ns_request.vector_column, Some("vector".to_string()));
         assert_eq!(ns_request.distance_type, Some("l2".to_string()));
         assert!(ns_request.vector.single_vector.is_some());
@@ -5199,16 +5179,17 @@ mod tests {
         let any_query = AnyQuery::Query(q);
         let ns_request = table.convert_to_namespace_query(&any_query).unwrap();
 
-        let column_names = ns_request
-            .columns
-            .as_ref()
-            .and_then(|cols| cols.column_names.clone());
-
         // Plain queries should pass an empty vector
         assert_eq!(ns_request.k, 20);
         assert_eq!(ns_request.offset, Some(5));
         assert_eq!(ns_request.filter, Some("id > 5".to_string()));
-        assert_eq!(column_names, Some(vec!["id".to_string()]));
+        assert_eq!(
+            ns_request
+                .columns
+                .as_ref()
+                .and_then(|c| c.column_names.as_ref()),
+            Some(&vec!["id".to_string()])
+        );
         assert_eq!(ns_request.with_row_id, Some(true));
         assert_eq!(ns_request.bypass_vector_index, Some(true));
         assert!(ns_request.vector_column.is_none()); // No vector column for plain queries


@@ -100,8 +100,7 @@ impl DatasetRef {
         let should_checkout = match &target_ref {
             refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
             refs::Ref::Version(_, None) => true, // No specific version, always checkout
-            refs::Ref::VersionNumber(target_ver) => version != target_ver,
             refs::Ref::Tag(_) => true, // Always checkout for tags
         };
 
         if should_checkout {


@@ -0,0 +1,161 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use serde::{Deserialize, Serialize};
use super::NativeTable;
use crate::Result;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
    /// The commit version associated with the operation.
    /// A version of `0` indicates compatibility with legacy servers that do not return
    /// a commit version.
    #[serde(default)]
    pub version: u64,
}
/// Internal implementation of the delete logic
///
/// This logic was moved from NativeTable::delete to keep table.rs clean.
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
// We access the dataset from the table. Since this is in the same module hierarchy (super),
// and 'dataset' is pub(crate), we can access it.
let mut dataset = table.dataset.get_mut().await?;
// Perform the actual delete on the Lance dataset
dataset.delete(predicate).await?;
// Return the result with the new version
Ok(DeleteResult {
version: dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use crate::connect;
use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
use crate::query::ExecutableQuery;
use futures::TryStreamExt;
#[tokio::test]
async fn test_delete_simple() {
let conn = connect("memory://").execute().await.unwrap();
// 1. Create a table with values 0 to 9
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)
.unwrap();
let table = conn
.create_table(
"test_delete",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// 2. Verify initial state
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// 3. Execute Delete (removes values > 5)
table.delete("i > 5").await.unwrap();
// 4. Verify results
assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain
// 5. Verify specific data consistency
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let array = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
// Ensure no value > 5 exists
for val in array.iter() {
assert!(val.unwrap() <= 5);
}
}
#[tokio::test]
async fn rows_removed_schema_same() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("id", Int32, [1, 2, 3, 4, 5]),
("name", Utf8, ["a", "b", "c", "d", "e"])
)
.unwrap();
let original_schema = batch.schema();
let table = conn
.create_table(
"test_delete_all",
RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()),
)
.execute()
.await
.unwrap();
table.delete("true").await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 0);
        let current_schema = table.schema().await.unwrap();
        // Check that the original schema is preserved after deleting all rows
        assert_eq!(current_schema, original_schema);
}
#[tokio::test]
async fn test_delete_false_increments_version() {
let conn = connect("memory://").execute().await.unwrap();
// Create a table with 5 rows
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_delete_noop",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Capture the initial state (Rows = 5, Version = 1)
let initial_rows = table.count_rows(None).await.unwrap();
let initial_version = table.version().await.unwrap();
assert_eq!(initial_rows, 5);
table.delete("false").await.unwrap();
// Rows should still be 5
let current_rows = table.count_rows(None).await.unwrap();
assert_eq!(
current_rows, initial_rows,
"Data should not change when predicate is false"
);
// version check
let current_version = table.version().await.unwrap();
assert!(
current_version > initial_version,
"Table version must increment after delete operation"
);
}
}


@@ -4,7 +4,6 @@
 use std::{
     borrow::Cow,
     collections::{HashMap, HashSet},
-    iter::repeat,
     sync::Arc,
 };
@@ -268,9 +267,10 @@ fn create_some_records() -> Result<impl IntoArrow> {
         schema.clone(),
         vec![
             Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
-            Arc::new(StringArray::from_iter(
-                repeat(Some("hello world".to_string())).take(TOTAL),
-            )),
+            Arc::new(StringArray::from_iter(std::iter::repeat_n(
+                Some("hello world".to_string()),
+                TOTAL,
+            ))),
         ],
     )
     .unwrap()]