Compare commits

...

41 Commits

Author SHA1 Message Date
Lance Release
027d53500b Bump version: 0.29.2-beta.0 → 0.29.2 2026-02-09 06:05:42 +00:00
Lance Release
9098f47e73 Bump version: 0.29.1 → 0.29.2-beta.0 2026-02-09 06:05:40 +00:00
Jack Ye
826a3e5ee9 ci(nodejs): add repository field to package.json for npm provenance (#3003)
## Summary

- Added `repository` field to all nodejs package.json files (main
package + 7 platform-specific packages)
- This fixes the npm publish E422 error where sigstore provenance
verification fails because the repository.url was empty

## Root Cause

Failing CI:
https://github.com/lancedb/lancedb/actions/runs/21770794768/job/62821570260

npm's sigstore provenance verification requires the `repository.url`
field in package.json to match the GitHub repository URL from the
provenance bundle. The platform-specific packages
(`@lancedb/lancedb-darwin-arm64`, etc.) were missing this field
entirely, causing the publish to fail with:

```
npm error 422 Unprocessable Entity - Error verifying sigstore provenance bundle: 
Failed to validate repository information: package.json: "repository.url" is "", 
expected to match "https://github.com/lancedb/lancedb" from provenance
```

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-08 22:04:32 -08:00
Lance Release
9fac56252e Bump version: 0.26.1-beta.0 → 0.26.1 2026-02-07 00:33:18 +00:00
Lance Release
c55ca20c1b Bump version: 0.26.0 → 0.26.1-beta.0 2026-02-07 00:33:02 +00:00
Lance Release
5cdb15feef Bump version: 0.29.1-beta.0 → 0.29.1 2026-02-07 00:32:44 +00:00
Lance Release
7a3eea927f Bump version: 0.29.0 → 0.29.1-beta.0 2026-02-07 00:32:42 +00:00
Jack Ye
5dd9b072d8 ci: upgrade node version for publishing (#2993)
Trusted publishing requires npm >=11.5.1, which means node>=24.

Also need `npm config set provenance true` to fully enable it
2026-02-06 16:30:46 -08:00
Abhishek
6dde379d44 refactor: extract schema evolution logic from table.rs into submodule (#2973)
Continues the modularization effort of schema evolution operations as
outlined in #2949

## Summary
- Extracts schema evolution operations (add_columns, alter_columns,
drop_columns) from `table.rs` into `table/schema_evolution.rs`
- Public API remains unchanged via re-exports
## Test plan
- [x] All new schema evolution tests pass
- [x] All existing tests pass
- [x] `cargo clippy` passes with no warnings
  - [x] `cargo fmt --check` passes
2026-02-06 11:33:18 -08:00
Lance Release
55f09ef1cd Bump version: 0.26.0-beta.0 → 0.26.0 2026-02-06 18:08:30 +00:00
Lance Release
e9d8651d18 Bump version: 0.25.0-beta.0 → 0.26.0-beta.0 2026-02-06 18:08:08 +00:00
Lance Release
071f467571 Bump version: 0.29.0-beta.0 → 0.29.0 2026-02-06 18:07:49 +00:00
Lance Release
f83aa25119 Bump version: 0.28.0-beta.0 → 0.29.0-beta.0 2026-02-06 18:07:48 +00:00
Jack Ye
0a8fe4d026 ci: fix python version for latest release (#2989)
It was accidentally corrupted in
https://github.com/lancedb/lancedb/pull/2972
2026-02-06 10:07:03 -08:00
Jack Ye
3ad7be9825 fix: remove x86_64-apple-darwin from list of npm triples (#2987)
Missed during https://github.com/lancedb/lancedb/pull/2987
2026-02-06 09:43:44 -08:00
LanceDB Robot
589041d842 feat: update lance dependency to v2.0.0 (#2985)
## Summary
- Bump Lance Rust crates to v2.0.0 (from v2.0.0-rc.4) and update Java
`lance-core` to 2.0.0.
- Verified `cargo clippy --workspace --tests --all-features -- -D
warnings` and `cargo fmt --all`.
- Triggering tag: v2.0.0.
2026-02-05 17:39:32 -08:00
Jack Ye
2e4cd56ab1 ci: auto-publish lancedb java sdk (#2986)
Avoid the need to manually approve an artifact release in Maven Central
2026-02-05 16:30:32 -08:00
Jack Ye
6fd8586fa7 fix: avoid force push in codex workflows to work with v0.95.0 git safety (#2981)
## Summary
- Codex CLI v0.95.0 ([PR
#10258](https://github.com/openai/codex/pull/10258)) hardened git
command safety so force push (`git push -f`, `--force`,
`--force-with-lease`, `+refspec`) now requires approval, which blocks it
in non-interactive `exec` mode.
- This broke the
[codex-update-lance-dependency](https://github.com/lancedb/lancedb/actions/runs/21727536000/job/62673436482)
workflow — the job succeeded but failed to push the branch or create the
PR.
- Replace force push with `gh api` branch deletion followed by regular
`git push`.
- Also update the script to bump Java lance-core version which was
missing previously

## Test plan
- [x] Re-run the `Codex Update Lance Dependency` workflow with a test
tag to verify the push and PR creation succeed:
https://github.com/lancedb/lancedb/pull/2983

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 15:57:45 -08:00
Jack Ye
6329b57604 docs: update nodejs docs for storage options APIs (#2978)
Regenerate TypeScript docs to include the new initialStorageOptions()
and latestStorageOptions() methods added in #2966.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 16:07:58 -08:00
Will Jones
c51b13e70f ci: fix publish failure notifications being skipped (#2976)
## Summary

The `report-failure` jobs in npm, cargo, and pypi publish workflows
checked for
`release` or `workflow_dispatch` events, but these workflows are
triggered by tag
pushes where `github.event_name` is `push`. The condition was never
true, so failure
notifications were silently skipped.

- Use `startsWith(github.ref, 'refs/tags/...')` to match actual tag
triggers
- Add `failure()` to only notify on actual failures

This matches the pattern already used by `java-publish.yml`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 11:22:27 -08:00
Jack Ye
0859312b83 feat: add initial and latest storage options apis (#2966)
Expose `initial_storage_options()` and `latest_storage_options()` in
lance Dataset, in lancedb rust, python and typescript SDKs.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 10:31:39 -08:00
Weston Pace
a6e8ec8d48 ci: remove npm auth token to allow trusted publisher (#2975) 2026-02-04 07:28:42 -08:00
Jack Ye
bd2c6d0763 chore: update lance dependency to v2.0.0-rc.4 (#2972) 2026-02-03 14:38:39 -08:00
Will Jones
fbf4a53475 feat(rust): implement TableProvider::insert_into() for LanceDB tables (#2939)
Implements `InsertExec` and `RemoteInsertExec` to support running
inserts in DataFusion.

## Context

In https://github.com/lancedb/lancedb/pull/2929, I've prototyped moving
the insert pipeline into DataFusion. This will enable parallelism at two
levels:

1. Running preprocessing, such as casting the input schema or computing
embeddings
2. Writing out files

This PR is just the first part of running the actual writes. In the end,
the plans might look like:

```
InsertExec
  RepartitionExec num_partitions=<write_parallelism>
    ProjectionExec vector=compute_embedding()
      RepartitionExec num_partitions=<num_cpus>
        DataSourceExec
```

where `num_cpus` is used to take advantage of all cores, while
`write_parallelism` might be less than `num_cpus` if there are too few
rows to want to split writes across `num_cpus` files.

Later PRs will move the preprocessing steps into DataFusion, and then
hook this up to the `Table::add()` implementations.

## Relation to future SQL work

We eventually plan on having the Remote SDK go through a FlightSQL
endpoint. Then for most queries we will send just the SQL string to the
server, and not run any sort of DataFusion plan on the client.

However, I think writes will be a little special, especially bulk writes
where we need to upload large streams of data and likely want
parallelism. So we'll have different code paths for writes, and I think
using DataFusion makes sense, especially as long as we are doing the
pre-processing on the client side still.
2026-02-03 10:38:02 -08:00
Vedant Madane
d3e15f3e17 fix(node): allow bigint[] for takeRowIds (#2916)
## Summary

This PR changes takeRowIds to accept bigint[] instead of 
number[], matching the type of _rowid returned by withRowId().

## Problem

When retrieving row IDs using \withRowId()\ and querying them back with
takeRowIds(), users get an error because:

1. _rowid values are returned as JavaScript bigint
2. takeRowIds() expected number[]
3. NAPI failed to convert: Error: Failed to convert napi value BigInt
into rust type i64

## Reproduction

\\\js
import lancedb from '@lancedb/lancedb';

const db = await lancedb.connect('memory://');
const table = await db.createTable('test', [{ id: 1, vector: [1.0, 2.0]
}]);

const results = await table.query().withRowId().toArray();
const rowIds = results.map(row => row._rowid);

console.log('types:', rowIds.map(id => typeof id)); // ['bigint']
await table.takeRowIds(rowIds).toArray(); // ❌ Error before fix
\\\

## Solution

- Updated TypeScript signature from takeRowIds(rowIds: number[]) to
takeRowIds(rowIds: bigint[])
- Updated Rust NAPI binding to accept Vec<BigInt> and convert using
get_u64()

Fixes #2722

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2026-02-03 10:09:51 -08:00
ChinmayGowda71
9c017d8348 refactor: extract update logic to src/table/update.rs (#2964)
References #2949 Part 2 of table.rs refactor. Moved UpdateResult,
UpdateBuilder, and execution logic to src/table/update.rs. No functional
changes API remains identical.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2026-02-03 09:54:19 -08:00
Rashid Ul Islam
c3cc2530b7 feat(python): expose fast_search in synchronous API (Fixes #2612) (#2962)
Fixes #2612

This PR exposes the private _fast_search attribute via a public
fast_search() method in the synchronous LanceVectorQueryBuilder.

Previously, enabling fast search in the sync API required accessing a
private member (query._fast_search = True). This change aligns the
synchronous API with the Async and Remote APIs, allowing for cleaner,
more Pythonic method chaining.

Changes:
Added fast_search() method to LanceVectorQueryBuilder in
python/python/lancedb/query.py.
Added a unit test verifying the flag works with high-dimensional data
(2560 dims) and chaining.
Example Usage:

Before:

```
query = table.search(vector)
query._fast_search = True  # Private attribute usage
results = query.limit(10).to_pandas()
```

After:

```
results = (
    table.search(vector)
    .fast_search()
    .limit(10)
    .to_pandas()
)
```

Verification:
I have added a test case (test_fast_search_high_dimension) that
replicates the scenario described in the issue (2560 dimensions, cosine
distance) to ensure the pipeline constructs the query correctly without
errors.

Checklist:

- [ ]  I have added tests to cover my changes.
- [ ]  All new and existing tests passed.
- [ ]  Documentation has been updated (inline docstrings).

Signed-off-by: Rashidul Islam <rasidulislam71@gmail.com>
2026-02-03 09:17:27 -08:00
Lance Release
571295b0d9 Bump version: 0.24.1 → 0.25.0-beta.0 2026-02-03 04:48:34 +00:00
Lance Release
972c682857 Bump version: 0.27.1 → 0.28.0-beta.0 2026-02-03 04:47:20 +00:00
LuQQiu
4f8ee82730 chore: update lance core java version to 1.0.4 (#2971) 2026-02-02 20:43:36 -08:00
Will Jones
131024839f fix: include _rowid in hash and calculated split projections (#2965)
## Summary

- PR #2957 changed the permutation builder to only select `_rowid` from
the base table, but `Splitter::project()` for hash and calculated splits
replaced the selection entirely, dropping `_rowid`.
- Include `_rowid` in the column selections for hash and calculated
split projections.
- Fix a Python test that queried the permutation table for base table
columns no longer materialized.

Fixes the `test_split_hash`, `test_split_hash_with_discard`,
`test_split_calculated`, `test_shuffle_combined_with_splits`, and
`test_filter_with_splits` failures in `test_permutation.py`.

## Test plan

- [x] `cargo test -p lancedb -- permutation` (22 passed)
- [x] `pytest python/tests/test_permutation.py` (46 passed)
- [x] `npm test __test__/permutation.test.ts` (20 passed)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 16:27:58 -08:00
ChinmayGowda71
3c7ddf4d0c refactor: modularize table.rs and extract delete logic (#2952)
References #2949 Moved DeleteResult and delete() implementation to
src/table/delete.rs. No functional changes. Added a test delete which
works. Will work on refactoring update next.
2026-02-02 11:54:49 -08:00
Siyuan Huang
461176f9f2 docs: update REST API link in README.md (#2906)
Fix broken REST API docs link in README.md by replacing
https://docs.lancedb.com/api-reference/introduction (404) with
https://docs.lancedb.com/api-reference/rest
2026-01-30 15:49:41 -08:00
Aman Harsh
3b8996bb69 fix(python): cancel remote queries on sync API interruption (#2913)
Fixes #2898 

Problem:
Sync API cancellations didn’t stop remote query coroutines, so requests
could continue after interrupt.

Changes:
- Cancel run_coroutine_threadsafe futures on any BaseException in the
sync background loop
- Update cancellation test to avoid starting a real background thread
and cover GeneratorExit
2026-01-30 15:47:18 -08:00
Mesut-Doner
3755064e93 fix(rust): support embeddings in create_empty_table (#2961)
Fixes the Rust SDK's `create_empty_table` to properly support embedding
column definitions, bringing it to parity with the Python SDK.

## Problem

The Rust SDK's `Connection::create_empty_table` did not support setting
embedding columns. When using `.add_embedding()` on the builder, the
embedding column definitions were lost because
`TableDefinition::new_from_schema(schema)` marks all columns as physical
only, without embedding metadata.

The Python SDK worked around this by creating an empty record batch with
proper schema metadata rather than using `create_empty_table` directly.

## Solution
Modified `CreateTableBuilder<false>` to handle embeddings

Closes #2759
2026-01-30 15:44:18 -08:00
Xin Sun
8773b865a9 fix(python): uses PIL incorrectly and may raise AttributeError (#2954)
Importing `PIL` alone does not guarantee that the `Image` submodule is
loaded. In a clean environment where no other code has imported
`PIL.Image` before, `PIL.Image` does not exist on the `PIL` package,
which leads to the AttributeError.
2026-01-30 15:33:10 -08:00
fzowl
1ee29675b3 feat(python): adding VoyageAI v4 models (#2959)
Adding VoyageAI v4 models
 - with these, i added unit tests
 - added example code (tested!)
2026-01-30 15:16:03 -08:00
Weston Pace
9be28448f5 fix: don't store all columns in the permutation table (#2957)
The permutation table was always intended to be a small table of row id
pointers (and split id). However, it was accidentally doing a full
materialization of the base table 🤦

This PR changes the permutation builder to only store row id and split
id.
2026-01-29 16:06:36 -08:00
Lei Xu
357197bacc chore!: change support python version from 3.10 to 3.13 (#2955)
Python 3.9 is EOL since Oct 2025. and last two pyarrow builts were
against python3.10-3.13.

* This PR is contributed by codex-gpt5.2
2026-01-30 01:47:50 +08:00
Lei Xu
ad51e2dd1f fix: support pydantic list of structs or optional struct (#2953)
Closes #2950

*This code is generated by codex-gpt5.2*
2026-01-28 21:08:18 -08:00
Weston Pace
e9e904783c feat: allow the permutation builder memory limit to be configured by env var (#2946)
Running into issues with DF sorting again. This will at least allow the
memory limit to be set large to bypass problems.
2026-01-28 09:02:59 +05:30
76 changed files with 3963 additions and 1228 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.24.1"
current_version = "0.26.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -3,7 +3,7 @@ name: build-linux-wheel
description: "Build a manylinux wheel for lance"
inputs:
python-minor-version:
description: "8, 9, 10, 11, 12"
description: "10, 11, 12, 13"
required: true
args:
description: "--release"

View File

@@ -3,7 +3,7 @@ name: build_wheel
description: "Build a lance wheel"
inputs:
python-minor-version:
description: "8, 9, 10, 11"
description: "10, 11, 12, 13"
required: true
args:
description: "--release"

View File

@@ -3,7 +3,7 @@ name: build_wheel
description: "Build a lance wheel"
inputs:
python-minor-version:
description: "8, 9, 10, 11"
description: "10, 11, 12, 13, 14"
required: true
args:
description: "--release"

View File

@@ -42,7 +42,7 @@ jobs:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [build]
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
if: always() && failure() && startsWith(github.ref, 'refs/tags/v')
permissions:
contents: read
issues: write

View File

@@ -86,16 +86,17 @@ jobs:
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
Follow these steps exactly:
1. Use script "ci/set_lance_version.py" to update Lance dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
2. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
6. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
7. Push the branch to origin. If the branch already exists, force-push your changes.
8. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
9. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
1. Use script "ci/set_lance_version.py" to update Lance Rust dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
2. Update the Java lance-core dependency version in "java/pom.xml": change the "<lance-core.version>...</lance-core.version>" property to "${VERSION}".
3. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
4. After clippy succeeds, run "cargo fmt --all" to format the workspace.
5. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
6. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
7. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
8. Push the branch to origin. If the remote branch already exists, delete it first with "gh api -X DELETE repos/lancedb/lancedb/git/refs/heads/${BRANCH_NAME}" then push with "git push origin ${BRANCH_NAME}". Do NOT use "git push --force" or "git push -f".
9. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
10. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
11. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
Constraints:
- Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.

View File

@@ -41,7 +41,7 @@ jobs:
sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.10"
cache: "pip"

View File

@@ -8,6 +8,7 @@ on:
paths:
- Cargo.toml
- nodejs/**
- docs/src/js/**
- .github/workflows/nodejs.yml
- docker-compose.yml

View File

@@ -318,7 +318,7 @@ jobs:
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 20
node-version: 24
cache: npm
cache-dependency-path: nodejs/package-lock.json
registry-url: "https://registry.npmjs.org"
@@ -348,9 +348,9 @@ jobs:
run: find npm
- name: Publish
env:
NODE_AUTH_TOKEN: ${{ secrets.LANCEDB_NPM_REGISTRY_TOKEN }}
DRY_RUN: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
npm config set provenance true
ARGS="--access public"
if [[ $DRY_RUN == "true" ]]; then
ARGS="$ARGS --dry-run"
@@ -363,7 +363,7 @@ jobs:
name: Report Workflow Failure
runs-on: ubuntu-latest
needs: [build-lancedb, test-lancedb, publish]
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
if: always() && failure() && startsWith(github.ref, 'refs/tags/v')
permissions:
contents: read
issues: write

View File

@@ -44,12 +44,12 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: 3.8
python-version: "3.10"
- uses: ./.github/workflows/build_linux_wheel
with:
python-minor-version: 8
python-minor-version: 10
args: "--release --strip ${{ matrix.config.extra_args }}"
arm-build: ${{ matrix.config.platform == 'aarch64' }}
manylinux: ${{ matrix.config.manylinux }}
@@ -74,12 +74,12 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: 3.12
python-version: "3.13"
- uses: ./.github/workflows/build_mac_wheel
with:
python-minor-version: 8
python-minor-version: 10
args: "--release --strip --target ${{ matrix.config.target }} --features fp16kernels"
- uses: ./.github/workflows/upload_wheel
if: startsWith(github.ref, 'refs/tags/python-v')
@@ -95,12 +95,12 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: 3.12
python-version: "3.13"
- uses: ./.github/workflows/build_windows_wheel
with:
python-minor-version: 8
python-minor-version: 10
args: "--release --strip"
vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }}
- uses: ./.github/workflows/upload_wheel
@@ -181,7 +181,7 @@ jobs:
permissions:
contents: read
issues: write
if: always() && (github.event_name == 'release' || github.event_name == 'workflow_dispatch')
if: always() && failure() && startsWith(github.ref, 'refs/tags/python-v')
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/create-failure-issue

View File

@@ -36,9 +36,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
python-version: "3.13"
- name: Install ruff
run: |
pip install ruff==0.9.9
@@ -61,9 +61,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
python-version: "3.13"
- name: Install protobuf compiler
run: |
sudo apt update
@@ -90,9 +90,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
python-version: "3.13"
cache: "pip"
- name: Install protobuf
run: |
@@ -110,7 +110,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
python-minor-version: ["9", "12"]
python-minor-version: ["10", "13"]
runs-on: "ubuntu-24.04"
defaults:
run:
@@ -126,7 +126,7 @@ jobs:
sudo apt update
sudo apt install -y protobuf-compiler
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: 3.${{ matrix.python-minor-version }}
- uses: ./.github/workflows/build_linux_wheel
@@ -156,9 +156,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
python-version: "3.13"
- uses: ./.github/workflows/build_mac_wheel
with:
args: --profile ci
@@ -185,9 +185,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
python-version: "3.13"
- uses: ./.github/workflows/build_windows_wheel
with:
args: --profile ci
@@ -212,9 +212,9 @@ jobs:
sudo apt update
sudo apt install -y protobuf-compiler
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: 3.9
python-version: "3.10"
- name: Install lancedb
run: |
pip install "pydantic<2"

853
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,39 +15,40 @@ categories = ["database-implementations"]
rust-version = "1.88.0"
[workspace.dependencies]
lance = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=2.0.0", default-features = false }
lance-core = "=2.0.0"
lance-datagen = "=2.0.0"
lance-file = "=2.0.0"
lance-io = { "version" = "=2.0.0", default-features = false }
lance-index = "=2.0.0"
lance-linalg = "=2.0.0"
lance-namespace = "=2.0.0"
lance-namespace-impls = { "version" = "=2.0.0", default-features = false }
lance-table = "=2.0.0"
lance-testing = "=2.0.0"
lance-datafusion = "=2.0.0"
lance-encoding = "=2.0.0"
lance-arrow = "=2.0.0"
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "56.2", optional = false }
arrow-array = "56.2"
arrow-data = "56.2"
arrow-ipc = "56.2"
arrow-ord = "56.2"
arrow-schema = "56.2"
arrow-select = "56.2"
arrow-cast = "56.2"
arrow = { version = "57.2", optional = false }
arrow-array = "57.2"
arrow-data = "57.2"
arrow-ipc = "57.2"
arrow-ord = "57.2"
arrow-schema = "57.2"
arrow-select = "57.2"
arrow-cast = "57.2"
async-trait = "0"
datafusion = { version = "50.1", default-features = false }
datafusion-catalog = "50.1"
datafusion-common = { version = "50.1", default-features = false }
datafusion-execution = "50.1"
datafusion-expr = "50.1"
datafusion-physical-plan = "50.1"
datafusion = { version = "51.0", default-features = false }
datafusion-catalog = "51.0"
datafusion-common = { version = "51.0", default-features = false }
datafusion-execution = "51.0"
datafusion-expr = "51.0"
datafusion-physical-plan = "51.0"
datafusion-physical-expr = "51.0"
env_logger = "0.11"
half = { "version" = "2.6.0", default-features = false, features = [
half = { "version" = "2.7.1", default-features = false, features = [
"num-traits",
] }
futures = "0"

View File

@@ -66,7 +66,7 @@ Follow the [Quickstart](https://lancedb.com/docs/quickstart/) doc to set up Lanc
| Python SDK | https://lancedb.github.io/lancedb/python/python/ |
| Typescript SDK | https://lancedb.github.io/lancedb/js/globals/ |
| Rust SDK | https://docs.rs/lancedb/latest/lancedb/index.html |
| REST API | https://docs.lancedb.com/api-reference/introduction |
| REST API | https://docs.lancedb.com/api-reference/rest |
## **Join Us and Contribute**

View File

@@ -0,0 +1,62 @@
# VoyageAI Embeddings
Voyage AI provides cutting-edge embedding and rerankers.
Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.
Supported models are:
**Voyage-4 Series (Latest)**
- voyage-4 (1024 dims, general-purpose and multilingual retrieval, 320K batch tokens)
- voyage-4-lite (1024 dims, optimized for latency and cost, 1M batch tokens)
- voyage-4-large (1024 dims, best retrieval quality, 120K batch tokens)
**Voyage-3 Series**
- voyage-3
- voyage-3-lite
**Domain-Specific Models**
- voyage-finance-2
- voyage-multilingual-2
- voyage-law-2
- voyage-code-2
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|--------|---------|
| `name` | `str` | `None` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-4, voyage-4-lite, voyage-4-large, voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. |
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
Usage Example:
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
voyageai = EmbeddingFunctionRegistry
.get_instance()
.get("voyageai")
.create(name="voyage-3")
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.24.1</version>
<version>0.26.1</version>
</dependency>
```

View File

@@ -367,6 +367,27 @@ Use [Table.listIndices](Table.md#listindices) to find the names of the indices.
***
### initialStorageOptions()
```ts
abstract initialStorageOptions(): Promise<undefined | null | Record<string, string>>
```
Get the initial storage options that were passed in when opening this table.
For dynamically refreshed options (e.g., credential vending), use
[Table.latestStorageOptions](Table.md#lateststorageoptions).
Warning: This is an internal API and the return value is subject to change.
#### Returns
`Promise`&lt;`undefined` \| `null` \| `Record`&lt;`string`, `string`&gt;&gt;
The storage options, or undefined if no storage options were configured.
***
### isOpen()
```ts
@@ -381,6 +402,28 @@ Return true if the table has not been closed
***
### latestStorageOptions()
```ts
abstract latestStorageOptions(): Promise<undefined | null | Record<string, string>>
```
Get the latest storage options, refreshing from provider if configured.
This method is useful for credential vending scenarios where storage options
may be refreshed dynamically. If no dynamic provider is configured, this
returns the initial static options.
Warning: This is an internal API and the return value is subject to change.
#### Returns
`Promise`&lt;`undefined` \| `null` \| `Record`&lt;`string`, `string`&gt;&gt;
The storage options, or undefined if no storage options were configured.
***
### listIndices()
```ts
@@ -705,8 +748,11 @@ Create a query that returns a subset of the rows in the table.
#### Parameters
* **rowIds**: `number`[]
* **rowIds**: readonly (`number` \| `bigint`)[]
The row ids of the rows to return.
Row ids returned by `withRowId()` are `bigint`, so `bigint[]` is supported.
For convenience / backwards compatibility, `number[]` is also accepted (for
small row ids that fit in a safe integer).
#### Returns

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.24.1-final.0</version>
<version>0.26.1-final.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.24.1-final.0</version>
<version>0.26.1-final.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>1.0.0-rc.2</lance-core.version>
<lance-core.version>2.0.0</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
@@ -292,11 +292,12 @@
<plugin>
<groupId>org.sonatype.central</groupId>
<artifactId>central-publishing-maven-plugin</artifactId>
<version>0.4.0</version>
<version>0.8.0</version>
<extensions>true</extensions>
<configuration>
<publishingServerId>ossrh</publishingServerId>
<tokenAuth>true</tokenAuth>
<autoPublish>true</autoPublish>
</configuration>
</plugin>
<plugin>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.24.1"
version = "0.26.1"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -312,6 +312,66 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
expect(res.getChild("id")?.toJSON()).toEqual([2, 3]);
});
it("should support takeRowIds with bigint array", async () => {
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }]);
// Get actual row IDs using withRowId()
const allRows = await table.query().withRowId().toArray();
const rowIds = allRows.map((row) => row._rowid) as bigint[];
// Verify row IDs are bigint
expect(typeof rowIds[0]).toBe("bigint");
// Use takeRowIds with bigint array (the main use case from issue #2722)
const res = await table.takeRowIds([rowIds[0], rowIds[2]]).toArray();
expect(res.map((r) => r.id)).toEqual([1, 3]);
});
it("should support takeRowIds with number array for backwards compatibility", async () => {
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }]);
// Small row IDs can be passed as numbers
const res = await table.takeRowIds([0, 2]).toArray();
expect(res.map((r) => r.id)).toEqual([1, 3]);
});
it("should support takeRowIds with mixed bigint and number array", async () => {
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }]);
// Mixed array of bigint and number
const res = await table.takeRowIds([0n, 1, 2n]).toArray();
expect(res.map((r) => r.id)).toEqual([1, 2, 3]);
});
it("should throw for non-integer number in takeRowIds", () => {
expect(() => table.takeRowIds([1.5])).toThrow(
"Row id must be an integer (or bigint)",
);
expect(() => table.takeRowIds([0, 1.1, 2])).toThrow(
"Row id must be an integer (or bigint)",
);
});
it("should throw for negative number in takeRowIds", () => {
expect(() => table.takeRowIds([-1])).toThrow("Row id cannot be negative");
expect(() => table.takeRowIds([0, -5, 2])).toThrow(
"Row id cannot be negative",
);
});
it("should throw for unsafe large number in takeRowIds", () => {
// Number.MAX_SAFE_INTEGER + 1 is not safe
const unsafeNumber = Number.MAX_SAFE_INTEGER + 1;
expect(() => table.takeRowIds([unsafeNumber])).toThrow(
"Row id is too large for number; use bigint instead",
);
});
it("should reject negative bigint in takeRowIds", async () => {
await table.add([{ id: 1 }]);
// Negative bigint should be rejected by the Rust layer
expect(() => {
table.takeRowIds([-1n]);
}).toThrow("Row id cannot be negative");
});
it("should return the table as an instance of an arrow table", async () => {
const arrowTbl = await table.toArrow();
expect(arrowTbl).toBeInstanceOf(ArrowTable);
@@ -1520,9 +1580,9 @@ describe("when optimizing a dataset", () => {
it("delete unverified", async () => {
const version = await table.version();
const versionFile = `${tmpDir.name}/${table.name}.lance/_versions/${
version - 1
}.manifest`;
const versionFile = `${tmpDir.name}/${table.name}.lance/_versions/${String(
18446744073709551615n - (BigInt(version) - 1n),
).padStart(20, "0")}.manifest`;
fs.rmSync(versionFile);
let stats = await table.optimize({ deleteUnverified: false });

View File

@@ -347,9 +347,13 @@ export abstract class Table {
/**
* Create a query that returns a subset of the rows in the table.
* @param rowIds The row ids of the rows to return.
*
* Row ids returned by `withRowId()` are `bigint`, so `bigint[]` is supported.
* For convenience / backwards compatibility, `number[]` is also accepted (for
* small row ids that fit in a safe integer).
* @returns A builder that can be used to parameterize the query.
*/
abstract takeRowIds(rowIds: number[]): TakeQuery;
abstract takeRowIds(rowIds: readonly (bigint | number)[]): TakeQuery;
/**
* Create a search query to find the nearest neighbors
@@ -538,6 +542,35 @@ export abstract class Table {
*
*/
abstract stats(): Promise<TableStatistics>;
/**
* Get the initial storage options that were passed in when opening this table.
*
* For dynamically refreshed options (e.g., credential vending), use
* {@link Table.latestStorageOptions}.
*
* Warning: This is an internal API and the return value is subject to change.
*
* @returns The storage options, or undefined if no storage options were configured.
*/
abstract initialStorageOptions(): Promise<
Record<string, string> | null | undefined
>;
/**
* Get the latest storage options, refreshing from provider if configured.
*
* This method is useful for credential vending scenarios where storage options
* may be refreshed dynamically. If no dynamic provider is configured, this
* returns the initial static options.
*
* Warning: This is an internal API and the return value is subject to change.
*
* @returns The storage options, or undefined if no storage options were configured.
*/
abstract latestStorageOptions(): Promise<
Record<string, string> | null | undefined
>;
}
export class LocalTable extends Table {
@@ -686,8 +719,24 @@ export class LocalTable extends Table {
return new TakeQuery(this.inner.takeOffsets(offsets));
}
takeRowIds(rowIds: number[]): TakeQuery {
return new TakeQuery(this.inner.takeRowIds(rowIds));
takeRowIds(rowIds: readonly (bigint | number)[]): TakeQuery {
const ids = rowIds.map((id) => {
if (typeof id === "bigint") {
return id;
}
if (!Number.isInteger(id)) {
throw new Error("Row id must be an integer (or bigint)");
}
if (id < 0) {
throw new Error("Row id cannot be negative");
}
if (!Number.isSafeInteger(id)) {
throw new Error("Row id is too large for number; use bigint instead");
}
return BigInt(id);
});
return new TakeQuery(this.inner.takeRowIds(ids));
}
query(): Query {
@@ -858,6 +907,18 @@ export class LocalTable extends Table {
return await this.inner.stats();
}
async initialStorageOptions(): Promise<
Record<string, string> | null | undefined
> {
return await this.inner.initialStorageOptions();
}
async latestStorageOptions(): Promise<
Record<string, string> | null | undefined
> {
return await this.inner.latestStorageOptions();
}
mergeInsert(on: string | string[]): MergeInsertBuilder {
on = Array.isArray(on) ? on : [on];
return new MergeInsertBuilder(this.inner.mergeInsert(on), this.schema());

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.24.1",
"version": "0.26.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",
@@ -8,5 +8,9 @@
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
},
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,3 +0,0 @@
# `@lancedb/lancedb-darwin-x64`
This is the **x86_64-apple-darwin** binary for `@lancedb/lancedb`

View File

@@ -1,12 +0,0 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.24.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",
"files": ["lancedb.darwin-x64.node"],
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.24.1",
"version": "0.26.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",
@@ -9,5 +9,9 @@
"engines": {
"node": ">= 18"
},
"libc": ["glibc"]
"libc": ["glibc"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.24.1",
"version": "0.26.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",
@@ -9,5 +9,9 @@
"engines": {
"node": ">= 18"
},
"libc": ["musl"]
"libc": ["musl"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.24.1",
"version": "0.26.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",
@@ -9,5 +9,9 @@
"engines": {
"node": ">= 18"
},
"libc": ["glibc"]
"libc": ["glibc"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.24.1",
"version": "0.26.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",
@@ -9,5 +9,9 @@
"engines": {
"node": ">= 18"
},
"libc": ["musl"]
"libc": ["musl"],
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.24.1",
"version": "0.26.1",
"os": [
"win32"
],
@@ -14,5 +14,9 @@
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
},
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.24.1",
"version": "0.26.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",
@@ -8,5 +8,9 @@
"license": "Apache-2.0",
"engines": {
"node": ">= 18"
},
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
}
}

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.24.1",
"version": "0.26.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.24.1",
"version": "0.26.1",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.24.1",
"version": "0.26.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",
@@ -25,7 +25,6 @@
"triples": {
"defaults": false,
"additional": [
"x86_64-apple-darwin",
"aarch64-apple-darwin",
"x86_64-unknown-linux-gnu",
"aarch64-unknown-linux-gnu",
@@ -37,6 +36,10 @@
}
},
"license": "Apache-2.0",
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb"
},
"devDependencies": {
"@aws-sdk/client-dynamodb": "^3.33.0",
"@aws-sdk/client-kms": "^3.33.0",

View File

@@ -166,6 +166,19 @@ impl Table {
Ok(stats.into())
}
#[napi(catch_unwind)]
pub async fn initial_storage_options(&self) -> napi::Result<Option<HashMap<String, String>>> {
Ok(self.inner_ref()?.initial_storage_options().await)
}
#[napi(catch_unwind)]
pub async fn latest_storage_options(&self) -> napi::Result<Option<HashMap<String, String>>> {
self.inner_ref()?
.latest_storage_options()
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn update(
&self,
@@ -208,18 +221,24 @@ impl Table {
}
#[napi(catch_unwind)]
pub fn take_row_ids(&self, row_ids: Vec<i64>) -> napi::Result<TakeQuery> {
pub fn take_row_ids(&self, row_ids: Vec<BigInt>) -> napi::Result<TakeQuery> {
Ok(TakeQuery::new(
self.inner_ref()?.take_row_ids(
row_ids
.into_iter()
.map(|o| {
u64::try_from(o).map_err(|e| {
napi::Error::from_reason(format!(
"Failed to convert row id to u64: {}",
e
.map(|id| {
let (negative, value, lossless) = id.get_u64();
if negative {
Err(napi::Error::from_reason(
"Row id cannot be negative".to_string(),
))
})
} else if !lossless {
Err(napi::Error::from_reason(
"Row id is too large to fit in u64".to_string(),
))
} else {
Ok(value)
}
})
.collect::<Result<Vec<_>>>()?,
),

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.27.1"
current_version = "0.29.2"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -16,7 +16,7 @@ The Python package is a wrapper around the Rust library, `lancedb`. We use
To set up your development environment, you will need to install the following:
1. Python 3.9 or later
1. Python 3.10 or later
2. Cargo (Rust's package manager). Use [rustup](https://rustup.rs/) to install.
3. [protoc](https://grpc.io/docs/protoc-installation/) (Protocol Buffers compiler)

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.27.1"
version = "0.29.2"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
@@ -14,15 +14,15 @@ name = "_lancedb"
crate-type = ["cdylib"]
[dependencies]
arrow = { version = "56.2", features = ["pyarrow"] }
arrow = { version = "57.2", features = ["pyarrow"] }
async-trait = "0.1"
lancedb = { path = "../rust/lancedb", default-features = false }
lance-core.workspace = true
lance-namespace.workspace = true
lance-io.workspace = true
env_logger.workspace = true
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.25", features = [
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.26", features = [
"attributes",
"tokio-runtime",
] }
@@ -32,7 +32,7 @@ snafu.workspace = true
tokio = { version = "1.40", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.25", features = [
pyo3-build-config = { version = "0.26", features = [
"extension-module",
"abi3-py39",
] }

View File

@@ -16,7 +16,7 @@ description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
license = { file = "LICENSE" }
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
keywords = [
"data-format",
"data-science",
@@ -33,10 +33,10 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering",
]
@@ -137,4 +137,4 @@ include = [
"python/lancedb/_lancedb.pyi",
]
exclude = ["python/tests/"]
pythonVersion = "3.12"
pythonVersion = "3.13"

View File

@@ -180,6 +180,8 @@ class Table:
delete_unverified: Optional[bool] = None,
) -> OptimizeStats: ...
async def uri(self) -> str: ...
async def initial_storage_options(self) -> Optional[Dict[str, str]]: ...
async def latest_storage_options(self) -> Optional[Dict[str, str]]: ...
@property
def tags(self) -> Tags: ...
def query(self) -> Query: ...

View File

@@ -22,7 +22,12 @@ class BackgroundEventLoop:
self.thread.start()
def run(self, future):
return asyncio.run_coroutine_threadsafe(future, self.loop).result()
concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop)
try:
return concurrent_future.result()
except BaseException:
concurrent_future.cancel()
raise
LOOP = BackgroundEventLoop()

View File

@@ -275,7 +275,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
"""
Convert image inputs to PIL Images.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
@@ -285,12 +285,12 @@ class ColPaliEmbeddings(EmbeddingFunction):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
pil_images.append(PIL_Image.open(io.BytesIO(response.content)))
else:
with PIL.Image.open(image) as im:
with PIL_Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL.Image.open(io.BytesIO(image)))
pil_images.append(PIL_Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)

View File

@@ -77,8 +77,8 @@ class JinaEmbeddings(EmbeddingFunction):
if isinstance(inputs, list):
inputs = inputs
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, PIL.Image.Image):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, PIL_Image.Image):
inputs = [inputs]
return inputs
@@ -89,13 +89,13 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(image, (str, Path)):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
PIL = attempt_import_or_raise("PIL", "pillow")
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if parsed.scheme == "file":
pil_image = PIL.Image.open(parsed.path)
pil_image = PIL_Image.open(parsed.path)
elif parsed.scheme == "":
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
pil_image = PIL_Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
pil_image = PIL_Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
buffered = io.BytesIO()
@@ -103,9 +103,9 @@ class JinaEmbeddings(EmbeddingFunction):
image_bytes = buffered.getvalue()
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
PIL = attempt_import_or_raise("PIL", "pillow")
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(image, PIL.Image.Image):
if isinstance(image, PIL_Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
@@ -136,9 +136,9 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(query, (Path, bytes)):
return [self.generate_image_embedding(query)]
else:
PIL = attempt_import_or_raise("PIL", "pillow")
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL.Image.Image):
if isinstance(query, PIL_Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(

View File

@@ -71,8 +71,8 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
@@ -145,20 +145,20 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = attempt_import_or_raise("PIL", "pillow")
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return PIL_Image.open(io.BytesIO(image))
if isinstance(image, PIL_Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL.Image.open(parsed.path)
return PIL_Image.open(parsed.path)
elif parsed.scheme == "":
return PIL.Image.open(image if os.name == "nt" else parsed.path)
return PIL_Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
return PIL_Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")

View File

@@ -56,8 +56,8 @@ class SigLipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("SigLIP supports str or PIL Image as query")
@@ -127,21 +127,21 @@ class SigLipEmbeddings(EmbeddingFunction):
return image_features.cpu().detach().numpy().squeeze()
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL.Image.Image):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(image, PIL_Image.Image):
return image.convert("RGB") if image.mode != "RGB" else image
elif isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image)).convert("RGB")
return PIL_Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
if parsed.scheme == "file":
return PIL.Image.open(parsed.path).convert("RGB")
return PIL_Image.open(parsed.path).convert("RGB")
elif parsed.scheme == "":
path = image if os.name == "nt" else parsed.path
return PIL.Image.open(path).convert("RGB")
return PIL_Image.open(path).convert("RGB")
elif parsed.scheme.startswith("http"):
image_bytes = url_retrieve(image)
return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
else:
raise NotImplementedError("Only local and http(s) urls are supported")
else:

View File

@@ -21,6 +21,9 @@ if TYPE_CHECKING:
# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
"voyage-4": 320_000,
"voyage-4-lite": 1_000_000,
"voyage-4-large": 120_000,
"voyage-context-3": 32_000,
"voyage-3.5-lite": 1_000_000,
"voyage-3.5": 320_000,
@@ -61,7 +64,7 @@ def is_video_path(path: Path) -> bool:
def transform_input(input_data: Union[str, bytes, Path]):
PIL = attempt_import_or_raise("PIL", "pillow")
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(input_data, str):
if is_valid_url(input_data):
if is_video_url(input_data):
@@ -70,7 +73,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
content = {"type": "image_url", "image_url": input_data}
else:
content = {"type": "text", "text": input_data}
elif isinstance(input_data, PIL.Image.Image):
elif isinstance(input_data, PIL_Image.Image):
buffered = BytesIO()
input_data.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -79,7 +82,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, bytes):
img = PIL.Image.open(BytesIO(input_data))
img = PIL_Image.open(BytesIO(input_data))
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -98,7 +101,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
"video_base64": video_str,
}
else:
img = PIL.Image.open(input_data)
img = PIL_Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -116,8 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
"""
Sanitize the input to the embedding function.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL_Image.Image)):
inputs = [inputs]
elif isinstance(inputs, list):
pass # Already a list, use as-is
@@ -130,7 +133,7 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
f"Input type {type(inputs)} not allowed with multimodal model."
)
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
if not all(isinstance(x, (str, bytes, Path, PIL_Image.Image)) for x in inputs):
raise ValueError("Each input should be either str, bytes, Path or Image.")
return [transform_input(i) for i in inputs]
@@ -167,6 +170,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
name: str
The name of the model to use. List of acceptable models:
* voyage-4 (1024 dims, general-purpose and multilingual retrieval)
* voyage-4-lite (1024 dims, optimized for latency and cost)
* voyage-4-large (1024 dims, best retrieval quality)
* voyage-context-3
* voyage-3.5
* voyage-3.5-lite
@@ -215,6 +221,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
text_embedding_models: list = [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-3.5",
"voyage-3.5-lite",
"voyage-3",
@@ -252,6 +261,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
elif self.name == "voyage-code-2":
return 1536
elif self.name in [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-context-3",
"voyage-3.5",
"voyage-3.5-lite",

View File

@@ -275,7 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
return pa.timestamp("us", tz=tz)
elif getattr(py_type, "__origin__", None) in (list, tuple):
child = py_type.__args__[0]
return pa.list_(_py_type_to_arrow_type(child, field))
return _pydantic_list_child_to_arrow(child, field)
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
)
@@ -298,12 +298,18 @@ else:
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
def _safe_issubclass(candidate: Any, base: type) -> bool:
try:
return issubclass(candidate, base)
except TypeError:
return False
if inspect.isclass(tp):
if issubclass(tp, pydantic.BaseModel):
if _safe_issubclass(tp, pydantic.BaseModel):
# Struct
fields = _pydantic_model_to_fields(tp)
return pa.struct(fields)
if issubclass(tp, FixedSizeListMixin):
if _safe_issubclass(tp, FixedSizeListMixin):
if getattr(tp, "is_multi_vector", lambda: False)():
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
# For regular Vector
@@ -311,45 +317,67 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
return _py_type_to_arrow_type(tp, field)
def _pydantic_list_child_to_arrow(child: Any, field: FieldInfo) -> pa.DataType:
unwrapped = _unwrap_optional_annotation(child)
if unwrapped is not None:
return pa.list_(
pa.field("item", _pydantic_type_to_arrow_type(unwrapped, field), True)
)
return pa.list_(_pydantic_type_to_arrow_type(child, field))
def _unwrap_optional_annotation(annotation: Any) -> Any | None:
if isinstance(annotation, (_GenericAlias, GenericAlias)):
origin = annotation.__origin__
args = annotation.__args__
if origin == Union:
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
return non_none[0]
elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
args = annotation.__args__
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
return non_none[0]
return None
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType"""
unwrapped = _unwrap_optional_annotation(field.annotation)
if unwrapped is not None:
return _pydantic_type_to_arrow_type(unwrapped, field)
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin is list:
child = args[0]
return pa.list_(_py_type_to_arrow_type(child, field))
elif origin == Union:
if len(args) == 2 and args[1] is type(None):
return _pydantic_type_to_arrow_type(args[0], field)
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
if len(args) == 2:
for typ in args:
if typ is type(None):
continue
return _py_type_to_arrow_type(typ, field)
return _pydantic_list_child_to_arrow(child, field)
return _pydantic_type_to_arrow_type(field.annotation, field)
def is_nullable(field: FieldInfo) -> bool:
"""Check if a Pydantic FieldInfo is nullable."""
if _unwrap_optional_annotation(field.annotation) is not None:
return True
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin == Union:
if len(args) == 2 and args[1] is type(None):
if any(typ is type(None) for typ in args):
return True
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
for typ in args:
if typ is type(None):
return True
elif inspect.isclass(field.annotation) and issubclass(
field.annotation, FixedSizeListMixin
):
return field.annotation.nullable()
elif inspect.isclass(field.annotation):
try:
if issubclass(field.annotation, FixedSizeListMixin):
return field.annotation.nullable()
except TypeError:
return False
return False

View File

@@ -961,22 +961,27 @@ class LanceQueryBuilder(ABC):
>>> query = [100, 100]
>>> plan = table.search(query).analyze_plan()
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
AnalyzeExec verbose=true, metrics=[], cumulative_cpu=...
TracedExec, metrics=[], cumulative_cpu=...
ProjectionExec: expr=[...], metrics=[...], cumulative_cpu=...
GlobalLimitExec: skip=0, fetch=10, metrics=[...], cumulative_cpu=...
FilterExec: _distance@2 IS NOT NULL,
metrics=[output_rows=..., elapsed_compute=...], cumulative_cpu=...
SortExec: TopK(fetch=10), expr=[...],
AnalyzeExec verbose=true, elapsed=..., metrics=...
TracedExec, elapsed=..., metrics=...
ProjectionExec: elapsed=..., expr=[...],
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...]
GlobalLimitExec: elapsed=..., skip=0, fetch=10,
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...]
FilterExec: elapsed=..., _distance@2 IS NOT NULL, metrics=[...]
SortExec: elapsed=..., TopK(fetch=10), expr=[...],
preserve_partitioning=[...],
metrics=[output_rows=..., elapsed_compute=..., row_replacements=...],
cumulative_cpu=...
KNNVectorDistance: metric=l2,
metrics=[output_rows=..., elapsed_compute=..., output_batches=...],
cumulative_cpu=...
LanceRead: uri=..., projection=[vector], ...
metrics=[output_rows=..., elapsed_compute=...,
bytes_read=..., iops=..., requests=...], cumulative_cpu=...
metrics=[output_rows=..., elapsed_compute=...,
output_bytes=..., row_replacements=...]
KNNVectorDistance: elapsed=..., metric=l2,
metrics=[output_rows=..., elapsed_compute=...,
output_bytes=..., output_batches=...]
LanceRead: elapsed=..., uri=..., projection=[vector],
num_fragments=..., range_before=None, range_after=None,
row_id=true, row_addr=false,
full_filter=--, refine_filter=--,
metrics=[output_rows=..., elapsed_compute=..., output_bytes=...,
fragments_scanned=..., ranges_scanned=1, rows_scanned=1,
bytes_read=..., iops=..., requests=..., task_wait_time=...]
Returns
-------
@@ -1428,6 +1433,19 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._bypass_vector_index = True
return self
def fast_search(self) -> LanceVectorQueryBuilder:
"""
Skip a flat search of unindexed data. This will improve
search performance but search results will not include unindexed data.
Returns
-------
LanceVectorQueryBuilder
The LanceVectorQueryBuilder object.
"""
self._fast_search = True
return self
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""

View File

@@ -2222,6 +2222,37 @@ class LanceTable(Table):
def uri(self) -> str:
return LOOP.run(self._table.uri())
def initial_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the initial storage options that were passed in when opening this table.
For dynamically refreshed options (e.g., credential vending), use
:meth:`latest_storage_options`.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return LOOP.run(self._table.initial_storage_options())
def latest_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the latest storage options, refreshing from provider if configured.
This method is useful for credential vending scenarios where storage options
may be refreshed dynamically. If no dynamic provider is configured, this
returns the initial static options.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return LOOP.run(self._table.latest_storage_options())
def create_scalar_index(
self,
column: str,
@@ -3624,6 +3655,37 @@ class AsyncTable:
"""
return await self._inner.uri()
async def initial_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the initial storage options that were passed in when opening this table.
For dynamically refreshed options (e.g., credential vending), use
:meth:`latest_storage_options`.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return await self._inner.initial_storage_options()
async def latest_storage_options(self) -> Optional[Dict[str, str]]:
"""Get the latest storage options, refreshing from provider if configured.
This method is useful for credential vending scenarios where storage options
may be refreshed dynamically. If no dynamic provider is configured, this
returns the initial static options.
Warning: This is an internal API and the return value is subject to change.
Returns
-------
Optional[Dict[str, str]]
The storage options, or None if no storage options were configured.
"""
return await self._inner.latest_storage_options()
async def add(
self,
data: DATA,

View File

@@ -517,19 +517,36 @@ def test_ollama_embedding(tmp_path):
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function():
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
@pytest.mark.parametrize(
"model_name,expected_dims",
[
("voyage-3", 1024),
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
],
)
def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
"""Integration test for VoyageAI text embedding models with real API calls."""
voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == expected_dims, (
f"{model_name} should have {expected_dims} dimensions"
)
# Test search functionality
result = tbl.search("hello").limit(1).to_pandas()
assert result["text"][0] == "hello world"
@pytest.mark.slow

View File

@@ -438,11 +438,15 @@ def test_filter_with_splits(mem_db):
row_count = permutation_tbl.count_rows()
assert row_count == 67
data = permutation_tbl.search(None).to_arrow().to_pydict()
# Verify the permutation table only contains row_id and split_id
assert set(permutation_tbl.schema.names) == {"row_id", "split_id"}
row_ids = permutation_tbl.search(None).to_arrow().to_pydict()["row_id"]
data = tbl.take_row_ids(row_ids).to_arrow().to_pydict()
categories = data["category"]
# All categories should be A or B
assert all(cat in ["A", "B"] for cat in categories)
assert all(cat in ("A", "B") for cat in categories)
def test_filter_with_shuffle(mem_db):

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import json
import sys
from datetime import date, datetime
from typing import List, Optional, Tuple
@@ -20,10 +19,6 @@ from pydantic import BaseModel
from pydantic import Field
@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow():
class StructModel(pydantic.BaseModel):
a: str
@@ -83,10 +78,6 @@ def test_pydantic_to_arrow():
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="using | type syntax requires python3.10 or higher",
)
def test_optional_types_py310():
class TestModel(pydantic.BaseModel):
a: str | None
@@ -105,10 +96,233 @@ def test_optional_types_py310():
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info > (3, 8),
reason="using native type alias requires python3.9 or higher",
)
def test_optional_structs():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
split: SplitInfo | None = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"split",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
),
]
)
assert schema == expect_schema
def test_optional_struct_list_py310():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[SplitInfo] | None = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
)
),
True,
),
]
)
assert schema == expect_schema
def test_nested_struct_list():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[SplitInfo]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
)
),
False,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: Optional[list[SplitInfo]] = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
)
),
True,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional_items():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[Optional[SplitInfo]]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.field(
"item",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
)
),
False,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional_container_and_items():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: Optional[list[Optional[SplitInfo]]] = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.field(
"item",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
)
),
True,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional_items_pep604():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[SplitInfo | None]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.field(
"item",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
)
),
False,
),
]
)
assert schema == expect_schema
def test_pydantic_to_arrow_py38():
class StructModel(pydantic.BaseModel):
a: str

View File

@@ -1499,3 +1499,30 @@ def test_search_empty_table(mem_db):
# Search on empty table should return empty results, not crash
results = table.search([1.0, 2.0]).limit(5).to_list()
assert results == []
def test_fast_search(tmp_path):
db = lancedb.connect(tmp_path)
# Generate data matching the async test style
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.random.rand(256, 32)
).storage
table = db.create_table("test", pa.table({"vector": vectors}))
# FIX: Pass arguments directly instead of using 'config=IvfPq(...)'
table.create_index(vector_column_name="vector", num_partitions=1, num_sub_vectors=1)
# Add data to ensure table has enough segments/rows
table.add(pa.table({"vector": vectors}))
q = [1.0] * 32
# 1. Normal Search -> Should include "LanceScan" (Brute Force / Scan)
plan = table.search(q).explain_plan(True)
assert "LanceScan" in plan
# 2. Fast Search -> Should NOT include "LanceScan" (Uses Index)
plan = table.search(q).fast_search().explain_plan(True)
assert "LanceScan" not in plan

View File

@@ -8,7 +8,7 @@ import http.server
import json
import threading
import time
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import uuid
from packaging.version import Version
@@ -601,7 +601,6 @@ def test_head():
def test_query_sync_minimal():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"refine_factor": None,
@@ -685,7 +684,6 @@ def test_query_sync_maximal():
def test_query_sync_nprobes():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"fast_search": True,
@@ -715,7 +713,6 @@ def test_query_sync_nprobes():
def test_query_sync_no_max_nprobes():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": True,
"fast_search": True,
@@ -838,7 +835,6 @@ def test_query_sync_hybrid():
else:
# Vector query
assert body == {
"distance_type": "l2",
"k": 42,
"prefilter": True,
"refine_factor": None,
@@ -1203,3 +1199,22 @@ async def test_header_provider_overrides_static_headers():
extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
) as db:
await db.table_names()
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
def test_background_loop_cancellation(exception):
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
from lancedb.background_loop import BackgroundEventLoop
mock_future = MagicMock()
mock_future.result.side_effect = exception()
with (
patch.object(BackgroundEventLoop, "__init__", return_value=None),
patch("asyncio.run_coroutine_threadsafe", return_value=mock_future),
):
loop = BackgroundEventLoop()
loop.loop = MagicMock()
with pytest.raises(exception):
loop.run(None)
mock_future.cancel.assert_called_once()

View File

@@ -1880,8 +1880,13 @@ async def test_optimize_delete_unverified(tmp_db_async: AsyncConnection, tmp_pat
],
)
version = await table.version()
path = tmp_path / "test.lance" / "_versions" / f"{version - 1}.manifest"
assert version == 2
# By removing a manifest file, we make the data files we just inserted unverified
version_name = 18446744073709551615 - (version - 1)
path = tmp_path / "test.lance" / "_versions" / f"{version_name:020}.manifest"
os.remove(path)
stats = await table.optimize(delete_unverified=False)
assert stats.prune.old_versions_removed == 0
stats = await table.optimize(

View File

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Unit tests for VoyageAI embedding function.
These tests verify model registration and configuration without requiring API calls.
"""
import pytest
from unittest.mock import MagicMock, patch
from lancedb.embeddings import get_registry
@pytest.fixture(autouse=True)
def reset_voyageai_client():
"""Reset VoyageAI client before and after each test to avoid state pollution."""
from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
VoyageAIEmbeddingFunction.client = None
yield
VoyageAIEmbeddingFunction.client = None
class TestVoyageAIModelRegistration:
"""Tests for VoyageAI model registration and configuration."""
@pytest.fixture
def mock_voyageai_client(self):
"""Mock VoyageAI client to avoid API calls."""
with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
mock_client = MagicMock()
mock_voyageai = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock.return_value = mock_voyageai
yield mock_client
def test_voyageai_registered(self):
"""Test that VoyageAI is registered in the embedding function registry."""
registry = get_registry()
assert registry.get("voyageai") is not None
@pytest.mark.parametrize(
"model_name,expected_dims",
[
# Voyage-4 series (all 1024 dims)
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
# Voyage-3 series
("voyage-3", 1024),
("voyage-3-lite", 512),
# Domain-specific models
("voyage-finance-2", 1024),
("voyage-multilingual-2", 1024),
("voyage-law-2", 1024),
("voyage-code-2", 1536),
# Multimodal
("voyage-multimodal-3", 1024),
],
)
def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
"""Test that each model returns the correct dimensions."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert func.ndims() == expected_dims, (
f"Model {model_name} should have {expected_dims} dimensions"
)
def test_unsupported_model_raises_error(self, mock_voyageai_client):
"""Test that unsupported models raise ValueError."""
registry = get_registry()
func = registry.get("voyageai").create(name="unsupported-model")
with pytest.raises(ValueError, match="not supported"):
func.ndims()
@pytest.mark.parametrize(
"model_name",
[
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
],
)
def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
"""Test that voyage-4 models are classified as text models (not multimodal)."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert not func._is_multimodal_model(model_name), (
f"{model_name} should be a text model, not multimodal"
)
def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
"""Test that voyage-4 models are in the text_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" in func.text_embedding_models
assert "voyage-4-lite" in func.text_embedding_models
assert "voyage-4-large" in func.text_embedding_models
def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
"""Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" not in func.multimodal_embedding_models
assert "voyage-4-lite" not in func.multimodal_embedding_models
assert "voyage-4-large" not in func.multimodal_embedding_models

View File

@@ -10,8 +10,7 @@ use arrow::{
use futures::stream::StreamExt;
use lancedb::arrow::SendableRecordBatchStream;
use pyo3::{
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult,
Python,
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -36,8 +35,11 @@ impl RecordBatchStream {
#[pymethods]
impl RecordBatchStream {
#[getter]
pub fn schema(&self, py: Python) -> PyResult<PyObject> {
(*self.schema).clone().into_pyarrow(py)
pub fn schema(&self, py: Python) -> PyResult<Py<PyAny>> {
(*self.schema)
.clone()
.into_pyarrow(py)
.map(|obj| obj.unbind())
}
pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
@@ -53,7 +55,12 @@ impl RecordBatchStream {
.next()
.await
.ok_or_else(|| PyStopAsyncIteration::new_err(""))?;
Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py))
Python::attach(|py| {
inner_next
.infer_error()?
.to_pyarrow(py)
.map(|obj| obj.unbind())
})
})
}
}

View File

@@ -12,7 +12,7 @@ use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods,
types::{PyDict, PyDictMethods},
Bound, FromPyObject, Py, PyAny, PyObject, PyRef, PyResult, Python,
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -114,7 +114,7 @@ impl Connection {
data: Bound<'_, PyAny>,
namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<PyObject>,
storage_options_provider: Option<Py<PyAny>>,
location: Option<String>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone();
@@ -152,7 +152,7 @@ impl Connection {
schema: Bound<'_, PyAny>,
namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<PyObject>,
storage_options_provider: Option<Py<PyAny>>,
location: Option<String>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone();
@@ -187,7 +187,7 @@ impl Connection {
name: String,
namespace: Vec<String>,
storage_options: Option<HashMap<String, String>>,
storage_options_provider: Option<PyObject>,
storage_options_provider: Option<Py<PyAny>>,
index_cache_size: Option<u32>,
location: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
@@ -307,7 +307,7 @@ impl Connection {
..Default::default()
};
let response = inner.list_namespaces(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
Python::attach(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("namespaces", response.namespaces)?;
dict.set_item("page_token", response.page_token)?;
@@ -345,7 +345,7 @@ impl Connection {
..Default::default()
};
let response = inner.create_namespace(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
Python::attach(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?;
Ok(dict.unbind())
@@ -386,7 +386,7 @@ impl Connection {
..Default::default()
};
let response = inner.drop_namespace(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
Python::attach(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?;
dict.set_item("transaction_id", response.transaction_id)?;
@@ -413,7 +413,7 @@ impl Connection {
..Default::default()
};
let response = inner.describe_namespace(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
Python::attach(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("properties", response.properties)?;
Ok(dict.unbind())
@@ -443,7 +443,7 @@ impl Connection {
..Default::default()
};
let response = inner.list_tables(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
Python::attach(|py| -> PyResult<Py<PyDict>> {
let dict = PyDict::new(py);
dict.set_item("tables", response.tables)?;
dict.set_item("page_token", response.page_token)?;

View File

@@ -40,7 +40,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
request_id,
source,
status_code,
} => Python::with_gil(|py| {
} => Python::attach(|py| {
let message = err.to_string();
let http_err_cls = py
.import(intern!(py, "lancedb.remote.errors"))?
@@ -75,7 +75,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
max_read_failures,
source,
status_code,
} => Python::with_gil(|py| {
} => Python::attach(|py| {
let cause_err = http_from_rust_error(
py,
source.as_ref(),

View File

@@ -12,7 +12,7 @@ pub struct PyHeaderProvider {
impl Clone for PyHeaderProvider {
fn clone(&self) -> Self {
Python::with_gil(|py| Self {
Python::attach(|py| Self {
provider: self.provider.clone_ref(py),
})
}
@@ -25,7 +25,7 @@ impl PyHeaderProvider {
/// Get headers from the Python provider (internal implementation)
fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> {
Python::with_gil(|py| {
Python::attach(|py| {
// Call the get_headers method
let result = self.provider.call_method0(py, "get_headers");

View File

@@ -281,7 +281,7 @@ impl PyPermutationReader {
let reader = slf.reader.clone();
future_into_py(slf.py(), async move {
let schema = reader.output_schema(selection).await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
})
}

View File

@@ -453,7 +453,7 @@ impl Query {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
})
}
@@ -532,7 +532,7 @@ impl TakeQuery {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
})
}
@@ -627,7 +627,7 @@ impl FTSQuery {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
})
}
@@ -806,7 +806,7 @@ impl VectorQuery {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
})
}

View File

@@ -17,20 +17,20 @@ use pyo3::types::PyDict;
/// Internal wrapper around a Python object implementing StorageOptionsProvider
pub struct PyStorageOptionsProvider {
/// The Python object implementing fetch_storage_options()
inner: PyObject,
inner: Py<PyAny>,
}
impl Clone for PyStorageOptionsProvider {
fn clone(&self) -> Self {
Python::with_gil(|py| Self {
Python::attach(|py| Self {
inner: self.inner.clone_ref(py),
})
}
}
impl PyStorageOptionsProvider {
pub fn new(obj: PyObject) -> PyResult<Self> {
Python::with_gil(|py| {
pub fn new(obj: Py<PyAny>) -> PyResult<Self> {
Python::attach(|py| {
// Verify the object has a fetch_storage_options method
if !obj.bind(py).hasattr("fetch_storage_options")? {
return Err(pyo3::exceptions::PyTypeError::new_err(
@@ -60,7 +60,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
let py_provider = self.py_provider.clone();
tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
Python::attach(|py| {
// Call the Python fetch_storage_options method
let result = py_provider
.inner
@@ -119,7 +119,7 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
}
fn provider_id(&self) -> String {
Python::with_gil(|py| {
Python::attach(|py| {
// Call provider_id() method on the Python object
let obj = self.py_provider.inner.bind(py);
obj.call_method0("provider_id")
@@ -143,7 +143,7 @@ impl std::fmt::Debug for PyStorageOptionsProviderWrapper {
/// This is the main entry point for converting Python StorageOptionsProvider objects
/// to Rust trait objects that can be used by the Lance ecosystem.
pub fn py_object_to_storage_options_provider(
py_obj: PyObject,
py_obj: Py<PyAny>,
) -> PyResult<Arc<dyn StorageOptionsProvider>> {
let py_provider = PyStorageOptionsProvider::new(py_obj)?;
Ok(Arc::new(PyStorageOptionsProviderWrapper::new(py_provider)))

View File

@@ -287,7 +287,7 @@ impl Table {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let schema = inner.schema().await.infer_error()?;
Python::with_gil(|py| schema.to_pyarrow(py))
Python::attach(|py| schema.to_pyarrow(py).map(|obj| obj.unbind()))
})
}
@@ -437,7 +437,7 @@ impl Table {
future_into_py(self_.py(), async move {
let stats = inner.index_stats(&index_name).await.infer_error()?;
if let Some(stats) = stats {
Python::with_gil(|py| {
Python::attach(|py| {
let dict = PyDict::new(py);
dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
dict.set_item("num_unindexed_rows", stats.num_unindexed_rows)?;
@@ -467,7 +467,7 @@ impl Table {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let stats = inner.stats().await.infer_error()?;
Python::with_gil(|py| {
Python::attach(|py| {
let dict = PyDict::new(py);
dict.set_item("total_bytes", stats.total_bytes)?;
dict.set_item("num_rows", stats.num_rows)?;
@@ -502,6 +502,20 @@ impl Table {
future_into_py(self_.py(), async move { inner.uri().await.infer_error() })
}
pub fn initial_storage_options(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
Ok(inner.initial_storage_options().await)
})
}
pub fn latest_storage_options(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.latest_storage_options().await.infer_error()
})
}
pub fn __repr__(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),
@@ -521,7 +535,7 @@ impl Table {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let versions = inner.list_versions().await.infer_error()?;
let versions_as_dict = Python::with_gil(|py| {
let versions_as_dict = Python::attach(|py| {
versions
.iter()
.map(|v| {
@@ -872,7 +886,7 @@ impl Tags {
let tags = inner.tags().await.infer_error()?;
let res = tags.list().await.infer_error()?;
Python::with_gil(|py| {
Python::attach(|py| {
let py_dict = PyDict::new(py);
for (key, contents) in res {
let value_dict = PyDict::new(py);

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.24.1"
version = "0.26.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -25,6 +25,7 @@ datafusion-catalog.workspace = true
datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datafusion-physical-plan.workspace = true
datafusion.workspace = true
object_store = { workspace = true }

View File

@@ -251,8 +251,36 @@ impl CreateTableBuilder<false> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
let parent = self.parent.clone();
let table = parent.create_table(self.request).await?;
Ok(Table::new(table, parent))
let embedding_registry = self.embedding_registry.clone();
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
parent,
embedding_registry,
))
}
fn into_request(self) -> Result<CreateTableRequest> {
if self.embeddings.is_empty() {
return Ok(self.request);
}
let CreateTableData::Empty(table_def) = self.request.data else {
unreachable!("CreateTableBuilder<false> should always have Empty data")
};
let schema = table_def.schema.clone();
let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone());
let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::<Vec<_>>());
let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema);
let with_embeddings = WithEmbeddings::new(reader, self.embeddings);
let table_definition = with_embeddings.table_definition()?;
Ok(CreateTableRequest {
data: CreateTableData::Empty(table_definition),
..self.request
})
}
}
@@ -1692,4 +1720,128 @@ mod tests {
let cloned_count = cloned_table.count_rows(None).await.unwrap();
assert_eq!(source_count, cloned_count);
}
#[tokio::test]
async fn test_create_empty_table_with_embeddings() {
use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction};
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use std::borrow::Cow;
#[derive(Debug, Clone)]
struct MockEmbedding {
dim: usize,
}
impl EmbeddingFunction for MockEmbedding {
fn name(&self) -> &str {
"test_embedding"
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
self.dim as i32,
true,
)))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let len = source.len();
let values = vec![1.0f32; len * self.dim];
let values = Arc::new(Float32Array::from(values));
let field = Arc::new(Field::new("item", DataType::Float32, true));
Ok(Arc::new(FixedSizeListArray::new(
field,
self.dim as i32,
values,
None,
)))
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let embed_func = Arc::new(MockEmbedding { dim: 128 });
db.embedding_registry()
.register("test_embedding", embed_func.clone())
.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let ed = EmbeddingDefinition {
source_column: "name".to_owned(),
dest_column: Some("name_embedding".to_owned()),
embedding_name: "test_embedding".to_owned(),
};
let table = db
.create_empty_table("test", schema)
.mode(CreateTableMode::Overwrite)
.add_embedding(ed)
.unwrap()
.execute()
.await
.unwrap();
let table_schema = table.schema().await.unwrap();
assert!(table_schema.column_with_name("name").is_some());
assert!(table_schema.column_with_name("name_embedding").is_some());
let embedding_field = table_schema.field_with_name("name_embedding").unwrap();
assert_eq!(
embedding_field.data_type(),
&DataType::new_fixed_size_list(DataType::Float32, 128, true)
);
let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let input_batch = RecordBatch::try_new(
input_schema.clone(),
vec![Arc::new(StringArray::from(vec![
Some("Alice"),
Some("Bob"),
Some("Charlie"),
]))],
)
.unwrap();
let input_reader = Box::new(RecordBatchIterator::new(
vec![Ok(input_batch)].into_iter(),
input_schema,
));
table.add(input_reader).execute().await.unwrap();
let results = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 3);
assert!(batch.column_by_name("name_embedding").is_some());
let embedding_col = batch
.column_by_name("name_embedding")
.unwrap()
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap();
assert_eq!(embedding_col.len(), 3);
}
}

View File

@@ -19,7 +19,7 @@ use crate::{
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase},
query::{ExecutableQuery, QueryBase, Select},
Error, Result, Table,
};
@@ -27,6 +27,8 @@ pub const SRC_ROW_ID_COL: &str = "row_id";
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
pub const DEFAULT_MEMORY_LIMIT: usize = 100 * 1024 * 1024;
/// Where to store the permutation table
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
@@ -167,10 +169,20 @@ impl PermutationBuilder {
&self,
data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
let memory_limit = std::env::var("LANCEDB_PERM_BUILDER_MEMORY_LIMIT")
.unwrap_or_else(|_| DEFAULT_MEMORY_LIMIT.to_string())
.parse::<usize>()
.unwrap_or_else(|_| {
log::error!(
"Failed to parse LANCEDB_PERM_BUILDER_MEMORY_LIMIT, using default: {}",
DEFAULT_MEMORY_LIMIT
);
DEFAULT_MEMORY_LIMIT
});
let ctx = SessionContext::new_with_config_rt(
SessionConfig::default(),
RuntimeEnvBuilder::new()
.with_memory_limit(100 * 1024 * 1024, 1.0)
.with_memory_limit(memory_limit, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default()
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
@@ -232,7 +244,7 @@ impl PermutationBuilder {
/// Builds the permutation table and stores it in the given database.
pub async fn build(self) -> Result<Table> {
// First pass, apply filter and load row ids
let mut rows = self.base_table.query().with_row_id();
let mut rows = self.base_table.query().select(Select::columns(&[ROW_ID]));
if let Some(filter) = &self.config.filter {
rows = rows.only_if(filter);
@@ -321,6 +333,47 @@ mod tests {
use super::*;
#[tokio::test]
async fn test_permutation_table_only_stores_row_id_and_split_id() {
let temp_dir = tempfile::tempdir().unwrap();
let db = connect(temp_dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let initial_data = lance_datagen::gen_batch()
.col("col_a", lance_datagen::array::step::<Int32Type>())
.col("col_b", lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
let data_table = db
.create_table_streaming("base_tbl", initial_data)
.execute()
.await
.unwrap();
let permutation_table = PermutationBuilder::new(data_table.clone())
.with_split_strategy(
SplitStrategy::Sequential {
sizes: SplitSizes::Percentages(vec![0.5, 0.5]),
},
None,
)
.with_filter("col_a > 57".to_string())
.build()
.await
.unwrap();
let schema = permutation_table.schema().await.unwrap();
let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
assert_eq!(
field_names,
vec!["row_id", "split_id"],
"Permutation table should only contain row_id and split_id columns, but found: {:?}",
field_names,
);
}
#[tokio::test]
async fn test_permutation_builder() {
let temp_dir = tempfile::tempdir().unwrap();
@@ -352,8 +405,6 @@ mod tests {
.await
.unwrap();
println!("permutation_table: {:?}", permutation_table);
// Potentially brittle seed-dependent values below
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
assert_eq!(

View File

@@ -12,6 +12,8 @@ use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance_arrow::SchemaExt;
use lance_core::ROW_ID;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
@@ -360,11 +362,15 @@ impl Splitter {
pub fn project(&self, query: Query) -> Query {
match &self.strategy {
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
SPLIT_ID_COLUMN.to_string(),
calculation.clone(),
)])),
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![
(SPLIT_ID_COLUMN.to_string(), calculation.clone()),
(ROW_ID.to_string(), ROW_ID.to_string()),
])),
SplitStrategy::Hash { columns, .. } => {
let mut cols = columns.clone();
cols.push(ROW_ID.to_string());
query.select(Select::Columns(cols))
}
_ => query,
}
}

View File

@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod insert;
use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
@@ -468,7 +470,9 @@ impl<S: HttpSend> RemoteTable<S> {
self.apply_query_params(&mut body, &query.base)?;
// Apply general parameters, before we dispatch based on number of query vectors.
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
if let Some(distance_type) = query.distance_type {
body["distance_type"] = serde_json::json!(distance_type);
}
// In 0.23.1 we migrated from `nprobes` to `minimum_nprobes` and `maximum_nprobes`.
// Old client / new server: since minimum_nprobes is missing, fallback to nprobes
// New client / old server: old server will only see nprobes, make sure to set both
@@ -1493,6 +1497,14 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
None
}
async fn initial_storage_options(&self) -> Option<HashMap<String, String>> {
None
}
async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
Ok(None)
}
async fn stats(&self) -> Result<TableStatistics> {
let request = self
.client
@@ -1508,6 +1520,21 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
})?;
Ok(stats)
}
async fn create_insert_exec(
&self,
input: Arc<dyn ExecutionPlan>,
write_params: lance::dataset::WriteParams,
) -> Result<Arc<dyn ExecutionPlan>> {
let overwrite = matches!(write_params.mode, lance::dataset::WriteMode::Overwrite);
Ok(Arc::new(insert::RemoteInsertExec::new(
self.name.clone(),
self.identifier.clone(),
self.client.clone(),
input,
overwrite,
)))
}
}
#[derive(Serialize)]
@@ -2230,7 +2257,6 @@ mod tests {
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let mut expected_body = serde_json::json!({
"prefilter": true,
"distance_type": "l2",
"nprobes": 20,
"minimum_nprobes": 20,
"maximum_nprobes": 20,

View File

@@ -0,0 +1,438 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! DataFusion ExecutionPlan for inserting data into remote LanceDB tables.
use std::any::Any;
use std::sync::{Arc, Mutex};
use arrow_array::{ArrayRef, RecordBatch, UInt64Array};
use arrow_ipc::CompressionType;
use arrow_schema::ArrowError;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::StreamExt;
use http::header::CONTENT_TYPE;
use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender};
use crate::remote::table::RemoteTable;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
use crate::table::datafusion::insert::COUNT_SCHEMA;
use crate::table::AddResult;
use crate::Error;
/// ExecutionPlan for inserting data into a remote LanceDB table.
///
/// This plan:
/// 1. Requires single partition (no parallel remote inserts yet)
/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint
/// 3. Stores AddResult for retrieval after execution
#[derive(Debug)]
pub struct RemoteInsertExec<S: HttpSend = Sender> {
table_name: String,
identifier: String,
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
properties: PlanProperties,
add_result: Arc<Mutex<Option<AddResult>>>,
}
impl<S: HttpSend + 'static> RemoteInsertExec<S> {
/// Create a new RemoteInsertExec.
pub fn new(
table_name: String,
identifier: String,
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
) -> Self {
let schema = COUNT_SCHEMA.clone();
let properties = PlanProperties::new(
EquivalenceProperties::new(schema),
datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
datafusion_physical_plan::execution_plan::EmissionType::Final,
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
);
Self {
table_name,
identifier,
client,
input,
overwrite,
properties,
add_result: Arc::new(Mutex::new(None)),
}
}
/// Get the add result after execution.
// TODO: this will be used when we wire this up to Table::add().
#[allow(dead_code)]
pub fn add_result(&self) -> Option<AddResult> {
self.add_result.lock().unwrap().clone()
}
fn stream_as_body(data: SendableRecordBatchStream) -> DataFusionResult<reqwest::Body> {
let options = arrow_ipc::writer::IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;
let writer = arrow_ipc::writer::StreamWriter::try_new_with_options(
Vec::new(),
&data.schema(),
options,
)?;
let stream = futures::stream::try_unfold((data, writer), move |(mut data, mut writer)| {
async move {
match data.next().await {
Some(Ok(batch)) => {
writer.write(&batch)?;
let buffer = std::mem::take(writer.get_mut());
Ok(Some((buffer, (data, writer))))
}
Some(Err(e)) => Err(e),
None => {
if let Err(ArrowError::IpcError(_msg)) = writer.finish() {
// Will error if already closed.
return Ok(None);
};
let buffer = std::mem::take(writer.get_mut());
Ok(Some((buffer, (data, writer))))
}
}
}
});
Ok(reqwest::Body::wrap_stream(stream))
}
}
impl<S: HttpSend + 'static> DisplayAs for RemoteInsertExec<S> {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
f,
"RemoteInsertExec: table={}, overwrite={}",
self.table_name, self.overwrite
)
}
DisplayFormatType::TreeRender => {
write!(f, "RemoteInsertExec")
}
}
}
}
impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
fn name(&self) -> &str {
Self::static_name()
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![false]
}
fn required_input_distribution(&self) -> Vec<datafusion_physical_plan::Distribution> {
// Until we have a separate commit endpoint, we need to do all inserts in a single partition
vec![datafusion_physical_plan::Distribution::SinglePartition]
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
vec![false]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
return Err(DataFusionError::Internal(
"RemoteInsertExec requires exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new(
self.table_name.clone(),
self.identifier.clone(),
self.client.clone(),
children[0].clone(),
self.overwrite,
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
if partition != 0 {
return Err(DataFusionError::Internal(
"RemoteInsertExec only supports single partition execution".to_string(),
));
}
let input_stream = self.input.execute(0, context)?;
let client = self.client.clone();
let identifier = self.identifier.clone();
let overwrite = self.overwrite;
let add_result = self.add_result.clone();
let table_name = self.table_name.clone();
let stream = futures::stream::once(async move {
let mut request = client
.post(&format!("/v1/table/{}/insert/", identifier))
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
if overwrite {
request = request.query(&[("mode", "overwrite")]);
}
let body = Self::stream_as_body(input_stream)?;
let request = request.body(body);
let (request_id, response) = client
.send(request)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let response =
RemoteTable::<Sender>::handle_table_not_found(&table_name, response, &request_id)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let response = client
.check_response(&request_id, response)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let body_text = response.text().await.map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: Box::new(e),
request_id: request_id.clone(),
status_code: None,
}))
})?;
let parsed_result = if body_text.trim().is_empty() {
// Backward compatible with old servers
AddResult { version: 0 }
} else {
serde_json::from_str(&body_text).map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: format!("Failed to parse add response: {}", e).into(),
request_id: request_id.clone(),
status_code: None,
}))
})?
};
{
let mut res_lock = add_result.lock().map_err(|_| {
DataFusionError::Execution("Failed to acquire lock for add_result".to_string())
})?;
*res_lock = Some(parsed_result);
}
// Return a single batch with count 0 (actual count is tracked in add_result)
let count_array: ArrayRef = Arc::new(UInt64Array::from(vec![0u64]));
let batch = RecordBatch::try_new(COUNT_SCHEMA.clone(), vec![count_array])?;
Ok::<_, DataFusionError>(batch)
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
COUNT_SCHEMA.clone(),
stream,
)))
}
}
#[cfg(test)]
mod tests {
use arrow_array::record_batch;
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
use crate::table::datafusion::BaseTableAdapter;
use crate::Table;
fn schema_json() -> &'static str {
r#"{"fields": [{"name": "id", "type": {"type": "int32"}, "nullable": true}]}"#
}
#[tokio::test]
async fn test_remote_insert_exec_execute_empty() {
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_clone = request_count.clone();
let table = Table::new_with_handler("my_table", move |request| {
let path = request.url().path();
if path == "/v1/table/my_table/describe/" {
// Return schema for BaseTableAdapter::try_new
return http::Response::builder()
.status(200)
.body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json()))
.unwrap();
}
if path == "/v1/table/my_table/insert/" {
assert_eq!(request.method(), "POST");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
);
request_count_clone.fetch_add(1, Ordering::SeqCst);
return http::Response::builder()
.status(200)
.body(r#"{"version": 2}"#.to_string())
.unwrap();
}
panic!("Unexpected request path: {}", path);
});
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
true,
)]));
// Create empty MemTable (no batches)
let source_table = MemTable::try_new(schema, vec![vec![]]).unwrap();
let ctx = SessionContext::new();
// Register the remote table as insert target
let provider = BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("my_table", Arc::new(provider)).unwrap();
// Register empty source
ctx.register_table("empty_source", Arc::new(source_table))
.unwrap();
// Execute the INSERT
ctx.sql("INSERT INTO my_table SELECT * FROM empty_source")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should have made exactly one HTTP request even with empty input
assert_eq!(request_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_remote_insert_exec_multi_partition() {
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_clone = request_count.clone();
let table = Table::new_with_handler("my_table", move |request| {
let path = request.url().path();
if path == "/v1/table/my_table/describe/" {
// Return schema for BaseTableAdapter::try_new
return http::Response::builder()
.status(200)
.body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json()))
.unwrap();
}
if path == "/v1/table/my_table/insert/" {
assert_eq!(request.method(), "POST");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
);
request_count_clone.fetch_add(1, Ordering::SeqCst);
return http::Response::builder()
.status(200)
.body(r#"{"version": 2}"#.to_string())
.unwrap();
}
panic!("Unexpected request path: {}", path);
});
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
true,
)]));
// Create MemTable with multiple partitions and multiple batches
let source_table = MemTable::try_new(
schema,
vec![
// Partition 0
vec![
record_batch!(("id", Int32, [1, 2])).unwrap(),
record_batch!(("id", Int32, [3, 4])).unwrap(),
],
// Partition 1
vec![record_batch!(("id", Int32, [5, 6, 7])).unwrap()],
// Partition 2
vec![record_batch!(("id", Int32, [8])).unwrap()],
],
)
.unwrap();
let ctx = SessionContext::new();
// Register the remote table as insert target
let provider = BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("my_table", Arc::new(provider)).unwrap();
// Register multi-partition source
ctx.register_table("multi_partition_source", Arc::new(source_table))
.unwrap();
// Get the physical plan and verify it includes a repartition to 1
let df = ctx
.sql("INSERT INTO my_table SELECT * FROM multi_partition_source")
.await
.unwrap();
let plan = df.clone().create_physical_plan().await.unwrap();
let plan_str = datafusion::physical_plan::displayable(plan.as_ref())
.indent(true)
.to_string();
// The plan should include a CoalescePartitionsExec to merge partitions
assert!(
plan_str.contains("CoalescePartitionsExec"),
"Expected CoalescePartitionsExec in plan:\n{}",
plan_str
);
// Execute the INSERT
df.collect().await.unwrap();
// Verify: should have made exactly one HTTP request despite multiple input partitions
assert_eq!(request_count.load(Ordering::SeqCst), 1);
}
}

View File

@@ -23,9 +23,7 @@ pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
pub use lance::dataset::Version;
use lance::dataset::{
InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams,
};
use lance::dataset::{InsertBuilder, WhenMatched, WriteMode, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::index::vector::utils::infer_vector_dim;
use lance::index::vector::VectorIndexParams;
@@ -79,10 +77,14 @@ use self::merge::MergeInsertBuilder;
pub mod datafusion;
pub(crate) mod dataset;
pub mod delete;
pub mod merge;
pub mod schema_evolution;
pub mod update;
use crate::index::waiter::wait_for_index;
pub use chrono::Duration;
pub use delete::DeleteResult;
use futures::future::{join_all, Either};
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
@@ -90,7 +92,9 @@ pub use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::statistics::DatasetStatisticsExt;
use lance_index::frag_reuse::FRAG_REUSE_INDEX_NAME;
pub use lance_index::optimize::OptimizeOptions;
pub use schema_evolution::{AddColumnsResult, AlterColumnsResult, DropColumnsResult};
use serde_with::skip_serializing_none;
pub use update::{UpdateBuilder, UpdateResult};
/// Defines the type of column
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -327,72 +331,6 @@ impl<T: IntoArrow> AddDataBuilder<T> {
}
}
/// A builder for configuring an [`Table::update`] operation
#[derive(Debug, Clone)]
pub struct UpdateBuilder {
parent: Arc<dyn BaseTable>,
pub(crate) filter: Option<String>,
pub(crate) columns: Vec<(String, String)>,
}
impl UpdateBuilder {
fn new(parent: Arc<dyn BaseTable>) -> Self {
Self {
parent,
filter: None,
columns: Vec::new(),
}
}
/// Limits the update operation to rows matching the given filter
///
/// If a row does not match the filter then it will be left unchanged.
pub fn only_if(mut self, filter: impl Into<String>) -> Self {
self.filter = Some(filter.into());
self
}
/// Specifies a column to update
///
/// This method may be called multiple times to update multiple columns
///
/// The `update_expr` should be an SQL expression explaining how to calculate
/// the new value for the column. The expression will be evaluated against the
/// previous row's value.
///
/// # Examples
///
/// ```
/// # use lancedb::Table;
/// # async fn doctest_helper(tbl: Table) {
/// let mut operation = tbl.update();
/// // Increments the `bird_count` value by 1
/// operation = operation.column("bird_count", "bird_count + 1");
/// operation.execute().await.unwrap();
/// # }
/// ```
pub fn column(
mut self,
column_name: impl Into<String>,
update_expr: impl Into<String>,
) -> Self {
self.columns.push((column_name.into(), update_expr.into()));
self
}
/// Executes the update operation.
/// Returns the update result
pub async fn execute(self) -> Result<UpdateResult> {
if self.columns.is_empty() {
Err(Error::InvalidInput {
message: "at least one column must be specified in an update operation".to_string(),
})
} else {
self.parent.clone().update(self).await
}
}
}
/// Filters that can be used to limit the rows returned by a query
pub enum Filter {
/// A SQL filter string
@@ -426,17 +364,6 @@ pub trait Tags: Send + Sync {
async fn update(&mut self, tag: &str, version: u64) -> Result<()>;
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct UpdateResult {
#[serde(default)]
pub rows_updated: u64,
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AddResult {
// The commit version associated with the operation.
@@ -446,15 +373,6 @@ pub struct AddResult {
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct MergeResult {
// The commit version associated with the operation.
@@ -480,33 +398,6 @@ pub struct MergeResult {
pub num_attempts: u32,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AddColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AlterColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DropColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// A trait for anything "table-like". This is used for both native tables (which target
/// Lance datasets) and remote tables (which target LanceDB cloud)
///
@@ -611,7 +502,17 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// Get the table URI (storage location)
async fn uri(&self) -> Result<String>;
/// Get the storage options used when opening this table, if any.
#[deprecated(since = "0.25.0", note = "Use initial_storage_options() instead")]
async fn storage_options(&self) -> Option<HashMap<String, String>>;
/// Get the initial storage options that were passed in when opening this table.
///
/// For dynamically refreshed options (e.g., credential vending), use [`Self::latest_storage_options`].
async fn initial_storage_options(&self) -> Option<HashMap<String, String>>;
/// Get the latest storage options, refreshing from provider if configured.
///
/// Returns `Ok(Some(options))` if storage options are available (static or refreshed),
/// `Ok(None)` if no storage options were configured, or `Err(...)` if refresh failed.
async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>>;
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
@@ -621,6 +522,19 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
) -> Result<()>;
/// Get statistics on the table
async fn stats(&self) -> Result<TableStatistics>;
/// Create an ExecutionPlan for inserting data into the table.
///
/// This is used by the DataFusion TableProvider implementation to support
/// INSERT INTO statements.
async fn create_insert_exec(
&self,
_input: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
_write_params: WriteParams,
) -> Result<Arc<dyn datafusion_physical_plan::ExecutionPlan>> {
Err(Error::NotSupported {
message: "create_insert_exec not implemented".to_string(),
})
}
}
/// A Table is a collection of strong typed Rows.
@@ -1328,10 +1242,32 @@ impl Table {
/// Get the storage options used when opening this table, if any.
///
/// Warning: This is an internal API and the return value is subject to change.
#[deprecated(since = "0.25.0", note = "Use initial_storage_options() instead")]
pub async fn storage_options(&self) -> Option<HashMap<String, String>> {
#[allow(deprecated)]
self.inner.storage_options().await
}
/// Get the initial storage options that were passed in when opening this table.
///
/// For dynamically refreshed options (e.g., credential vending), use [`Self::latest_storage_options`].
///
/// Warning: This is an internal API and the return value is subject to change.
pub async fn initial_storage_options(&self) -> Option<HashMap<String, String>> {
self.inner.initial_storage_options().await
}
/// Get the latest storage options, refreshing from provider if configured.
///
/// This method is useful for credential vending scenarios where storage options
/// may be refreshed dynamically. If no dynamic provider is configured, this
/// returns the initial static options.
///
/// Warning: This is an internal API and the return value is subject to change.
pub async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
self.inner.latest_storage_options().await
}
/// Get statistics about an index.
/// Returns None if the index does not exist.
pub async fn index_stats(
@@ -1425,7 +1361,9 @@ impl Table {
})
.collect::<Vec<_>>();
let unioned = Arc::new(UnionExec::new(projected_plans));
let unioned = UnionExec::try_new(projected_plans).map_err(|err| Error::Runtime {
message: err.to_string(),
})?;
// We require 1 partition in the final output
let repartitioned = RepartitionExec::try_new(
unioned,
@@ -2802,25 +2740,8 @@ impl BaseTable for NativeTable {
}
async fn update(&self, update: UpdateBuilder) -> Result<UpdateResult> {
let dataset = self.dataset.get().await?.clone();
let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
if let Some(predicate) = update.filter {
builder = builder.update_where(&predicate)?;
}
for (column, value) in update.columns {
builder = builder.set(column, &value)?;
}
let operation = builder.build()?;
let res = operation.execute().await?;
self.dataset
.set_latest(res.new_dataset.as_ref().clone())
.await;
Ok(UpdateResult {
rows_updated: res.rows_updated,
version: res.new_dataset.version().version,
})
// Delegate to the submodule implementation
update::execute_update(self, update).await
}
async fn create_plan(
@@ -3078,11 +2999,8 @@ impl BaseTable for NativeTable {
/// Delete rows from the table
async fn delete(&self, predicate: &str) -> Result<DeleteResult> {
let mut dataset = self.dataset.get_mut().await?;
dataset.delete(predicate).await?;
Ok(DeleteResult {
version: dataset.version().version,
})
// Delegate to the submodule implementation
delete::execute_delete(self, predicate).await
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
@@ -3148,27 +3066,15 @@ impl BaseTable for NativeTable {
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
) -> Result<AddColumnsResult> {
let mut dataset = self.dataset.get_mut().await?;
dataset.add_columns(transforms, read_columns, None).await?;
Ok(AddColumnsResult {
version: dataset.version().version,
})
schema_evolution::execute_add_columns(self, transforms, read_columns).await
}
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult> {
let mut dataset = self.dataset.get_mut().await?;
dataset.alter_columns(alterations).await?;
Ok(AlterColumnsResult {
version: dataset.version().version,
})
schema_evolution::execute_alter_columns(self, alterations).await
}
async fn drop_columns(&self, columns: &[&str]) -> Result<DropColumnsResult> {
let mut dataset = self.dataset.get_mut().await?;
dataset.drop_columns(columns).await?;
Ok(DropColumnsResult {
version: dataset.version().version,
})
schema_evolution::execute_drop_columns(self, columns).await
}
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
@@ -3231,6 +3137,10 @@ impl BaseTable for NativeTable {
}
async fn storage_options(&self) -> Option<HashMap<String, String>> {
self.initial_storage_options().await
}
async fn initial_storage_options(&self) -> Option<HashMap<String, String>> {
self.dataset
.get()
.await
@@ -3238,6 +3148,11 @@ impl BaseTable for NativeTable {
.and_then(|dataset| dataset.initial_storage_options().cloned())
}
async fn latest_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
let dataset = self.dataset.get().await?;
Ok(dataset.latest_storage_options().await?.map(|o| o.0))
}
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
let stats = match self
.dataset
@@ -3351,6 +3266,21 @@ impl BaseTable for NativeTable {
};
Ok(stats)
}
async fn create_insert_exec(
&self,
input: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
write_params: WriteParams,
) -> Result<Arc<dyn datafusion_physical_plan::ExecutionPlan>> {
let ds = self.dataset.get().await?;
let dataset = Arc::new((*ds).clone());
Ok(Arc::new(datafusion::insert::InsertExec::new(
self.dataset.clone(),
dataset,
input,
write_params,
)))
}
}
#[skip_serializing_none]
@@ -3406,15 +3336,12 @@ mod tests {
use arrow_array::{
builder::{ListBuilder, StringBuilder},
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
UInt32Array,
Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray,
RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
};
use arrow_array::{BinaryArray, LargeBinaryArray};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use futures::TryStreamExt;
use arrow_schema::{DataType, Field, Schema};
use lance::dataset::WriteMode;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance::Dataset;
@@ -3426,7 +3353,6 @@ mod tests {
use crate::connection::ConnectBuilder;
use crate::index::scalar::{BTreeIndexBuilder, BitmapIndexBuilder};
use crate::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder};
use crate::query::{ExecutableQuery, QueryBase};
#[tokio::test]
async fn test_open() {
@@ -3648,306 +3574,6 @@ mod tests {
assert_eq!(table.name(), "test");
}
#[tokio::test]
async fn test_update_with_predicate() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
table
.update()
.only_if("id > 5")
.column("name", "'foo'")
.execute()
.await
.unwrap();
let mut batches = table
.query()
.select(Select::columns(&["id", "name"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
while let Some(batch) = batches.pop() {
let ids = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.collect::<Vec<_>>();
let names = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for (i, name) in names.iter().enumerate() {
let id = ids[i].unwrap();
let name = name.unwrap();
if id > 5 {
assert_eq!(name, "foo");
} else {
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
}
}
}
}
#[tokio::test]
async fn test_update_all_types() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("int32", DataType::Int32, false),
Field::new("int64", DataType::Int64, false),
Field::new("uint32", DataType::UInt32, false),
Field::new("string", DataType::Utf8, false),
Field::new("large_string", DataType::LargeUtf8, false),
Field::new("float32", DataType::Float32, false),
Field::new("float64", DataType::Float64, false),
Field::new("bool", DataType::Boolean, false),
Field::new("date32", DataType::Date32, false),
Field::new(
"timestamp_ns",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new(
"timestamp_ms",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"vec_f32",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
Field::new(
"vec_f64",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
false,
),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(Int64Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(LargeStringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))),
Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))),
Arc::new(Into::<BooleanArray>::into(vec![
true, false, true, false, true, false, true, false, true, false,
])),
Arc::new(Date32Array::from_iter_values(0..10)),
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
Arc::new(
create_fixed_size_list(
Float32Array::from_iter_values((0..20).map(|i| i as f32)),
2,
)
.unwrap(),
),
Arc::new(
create_fixed_size_list(
Float64Array::from_iter_values((0..20).map(|i| i as f64)),
2,
)
.unwrap(),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
// check it can do update for each type
let updates: Vec<(&str, &str)> = vec![
("string", "'foo'"),
("large_string", "'large_foo'"),
("int32", "1"),
("int64", "1"),
("uint32", "1"),
("float32", "1.0"),
("float64", "1.0"),
("bool", "true"),
("date32", "1"),
("timestamp_ns", "1"),
("timestamp_ms", "1"),
("vec_f32", "[1.0, 1.0]"),
("vec_f64", "[1.0, 1.0]"),
];
let mut update_op = table.update();
for (column, value) in updates {
update_op = update_op.column(column, value);
}
update_op.execute().await.unwrap();
let mut batches = table
.query()
.select(Select::columns(&[
"string",
"large_string",
"int32",
"int64",
"uint32",
"float32",
"float64",
"bool",
"date32",
"timestamp_ns",
"timestamp_ms",
"vec_f32",
"vec_f64",
]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = batches.pop().unwrap();
macro_rules! assert_column {
($column:expr, $array_type:ty, $expected:expr) => {
let array = $column
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
assert_eq!(v, Some($expected));
}
};
}
assert_column!(batch.column(0), StringArray, "foo");
assert_column!(batch.column(1), LargeStringArray, "large_foo");
assert_column!(batch.column(2), Int32Array, 1);
assert_column!(batch.column(3), Int64Array, 1);
assert_column!(batch.column(4), UInt32Array, 1);
assert_column!(batch.column(5), Float32Array, 1.0);
assert_column!(batch.column(6), Float64Array, 1.0);
assert_column!(batch.column(7), BooleanArray, true);
assert_column!(batch.column(8), Date32Array, 1);
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
let array = batch
.column(11)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
for v in f32array {
assert_eq!(v, Some(1.0));
}
}
let array = batch
.column(12)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
for v in f64array {
assert_eq!(v, Some(1.0));
}
}
}
#[tokio::test]
async fn test_update_via_expr() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let tbl = conn
.create_table("my_table", make_test_batches())
.execute()
.await
.unwrap();
assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
tbl.update().column("i", "i+1").execute().await.unwrap();
assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
}
#[derive(Default, Debug)]
struct NoOpCacheWrapper {
called: AtomicBool,

View File

@@ -3,6 +3,7 @@
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
pub mod insert;
pub mod udtf;
use std::{collections::HashMap, sync::Arc};
@@ -13,11 +14,12 @@ use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion_expr::{dml::InsertOp, Expr, TableProviderFilterPushDown, TableType};
use datafusion_physical_plan::{
stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use futures::{TryFutureExt, TryStreamExt};
use lance::dataset::{WriteMode, WriteParams};
use super::{AnyQuery, BaseTable};
use crate::{
@@ -250,6 +252,33 @@ impl TableProvider for BaseTableAdapter {
// TODO
None
}
async fn insert_into(
&self,
_state: &dyn Session,
input: Arc<dyn ExecutionPlan>,
insert_op: InsertOp,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
let mode = match insert_op {
InsertOp::Append => WriteMode::Append,
InsertOp::Overwrite => WriteMode::Overwrite,
InsertOp::Replace => {
return Err(DataFusionError::NotImplemented(
"Replace mode is not supported for LanceDB tables".to_string(),
))
}
};
let write_params = WriteParams {
mode,
..Default::default()
};
self.table
.create_insert_exec(input, write_params)
.await
.map_err(|e| DataFusionError::External(e.into()))
}
}
#[cfg(test)]

View File

@@ -0,0 +1,446 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! DataFusion ExecutionPlan for inserting data into LanceDB tables.
use std::any::Any;
use std::sync::{Arc, LazyLock, Mutex};
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
};
use lance::dataset::transaction::{Operation, Transaction};
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
use lance::Dataset;
use lance_table::format::Fragment;
use crate::table::dataset::DatasetConsistencyWrapper;
pub(crate) static COUNT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
Arc::new(ArrowSchema::new(vec![Field::new(
"count",
DataType::UInt64,
false,
)]))
});
fn operation_fragments(operation: &Operation) -> &[Fragment] {
match operation {
Operation::Append { fragments } => fragments,
Operation::Overwrite { fragments, .. } => fragments,
_ => &[],
}
}
fn count_rows_from_operation(operation: &Operation) -> u64 {
operation_fragments(operation)
.iter()
.map(|f| f.num_rows().unwrap_or(0) as u64)
.sum()
}
fn operation_fragments_mut(operation: &mut Operation) -> &mut Vec<Fragment> {
match operation {
Operation::Append { fragments } => fragments,
Operation::Overwrite { fragments, .. } => fragments,
_ => panic!("Unsupported operation type for getting mutable fragments"),
}
}
fn merge_transactions(mut transactions: Vec<Transaction>) -> Option<Transaction> {
let mut first = transactions.pop()?;
for txn in transactions {
let first_fragments = operation_fragments_mut(&mut first.operation);
let txn_fragments = operation_fragments(&txn.operation);
first_fragments.extend_from_slice(txn_fragments);
}
Some(first)
}
/// ExecutionPlan for inserting data into a native LanceDB table.
///
/// This plan executes inserts by:
/// 1. Each partition writes data independently using InsertBuilder::execute_uncommitted_stream
/// 2. The last partition to complete commits all transactions atomically
/// 3. Returns the count of inserted rows per partition
#[derive(Debug)]
pub struct InsertExec {
ds_wrapper: DatasetConsistencyWrapper,
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
properties: PlanProperties,
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
}
impl InsertExec {
pub fn new(
ds_wrapper: DatasetConsistencyWrapper,
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
) -> Self {
let schema = COUNT_SCHEMA.clone();
let num_partitions = input.output_partitioning().partition_count();
let properties = PlanProperties::new(
EquivalenceProperties::new(schema),
Partitioning::UnknownPartitioning(num_partitions),
EmissionType::Final,
Boundedness::Bounded,
);
Self {
ds_wrapper,
dataset,
input,
write_params,
properties,
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
}
}
}
impl DisplayAs for InsertExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(f, "InsertExec: mode={:?}", self.write_params.mode)
}
DisplayFormatType::TreeRender => {
write!(f, "InsertExec")
}
}
}
}
impl ExecutionPlan for InsertExec {
fn name(&self) -> &str {
Self::static_name()
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![false]
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
vec![false]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
return Err(DataFusionError::Internal(
"InsertExec requires exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new(
self.ds_wrapper.clone(),
self.dataset.clone(),
children[0].clone(),
self.write_params.clone(),
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let input_stream = self.input.execute(partition, context)?;
let dataset = self.dataset.clone();
let write_params = self.write_params.clone();
let partial_transactions = self.partial_transactions.clone();
let total_partitions = self.input.output_partitioning().partition_count();
let ds_wrapper = self.ds_wrapper.clone();
let stream = futures::stream::once(async move {
let transaction = InsertBuilder::new(dataset.clone())
.with_params(&write_params)
.execute_uncommitted_stream(input_stream)
.await?;
let num_rows = count_rows_from_operation(&transaction.operation);
let to_commit = {
// Don't hold the lock over an await point.
let mut txns = partial_transactions.lock().unwrap();
txns.push(transaction);
if txns.len() == total_partitions {
Some(std::mem::take(&mut *txns))
} else {
None
}
};
if let Some(transactions) = to_commit {
if let Some(merged_txn) = merge_transactions(transactions) {
let new_dataset = CommitBuilder::new(dataset.clone())
.execute(merged_txn)
.await?;
ds_wrapper.set_latest(new_dataset).await;
}
}
Ok(RecordBatch::try_new(
COUNT_SCHEMA.clone(),
vec![Arc::new(UInt64Array::from(vec![num_rows]))],
)?)
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
COUNT_SCHEMA.clone(),
stream,
)))
}
}
#[cfg(test)]
mod tests {
use std::vec;
use super::*;
use arrow_array::{record_batch, Int32Array, RecordBatchIterator};
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use tempfile::tempdir;
use crate::connect;
#[tokio::test]
async fn test_insert_via_sql() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let table = db
.create_table("test_insert", Box::new(reader))
.execute()
.await
.unwrap();
// Verify initial count
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_insert", Arc::new(provider))
.unwrap();
ctx.sql("INSERT INTO test_insert VALUES (4), (5), (6)")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify final count
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 6);
}
#[tokio::test]
async fn test_insert_overwrite_via_sql() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table with 3 rows
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let table = db
.create_table("test_overwrite", Box::new(reader))
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_overwrite", Arc::new(provider))
.unwrap();
ctx.sql("INSERT OVERWRITE INTO test_overwrite VALUES (10), (20)")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should have 2 rows (overwritten, not appended)
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 2);
}
#[tokio::test]
async fn test_insert_empty_batch() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
false,
)]));
let batches = vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap()];
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let table = db
.create_table("test_empty", Box::new(reader))
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_empty", Arc::new(provider))
.unwrap();
let source_schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
false,
)]));
// Empty batches
let source_reader = RecordBatchIterator::new(
std::iter::empty::<Result<RecordBatch, arrow_schema::ArrowError>>(),
source_schema,
);
let source_table = db
.create_table("empty_source", Box::new(source_reader))
.execute()
.await
.unwrap();
let source_provider =
crate::table::datafusion::BaseTableAdapter::try_new(source_table.base_table().clone())
.await
.unwrap();
ctx.register_table("empty_source", Arc::new(source_provider))
.unwrap();
// Execute INSERT with empty source
ctx.sql("INSERT INTO test_empty SELECT * FROM empty_source")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should still have 3 rows (nothing inserted)
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 3);
}
#[tokio::test]
async fn test_insert_multiple_batches() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
// Create initial table
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"id",
DataType::Int32,
true,
)]));
let batches =
vec![
RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))])
.unwrap(),
];
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let table = db
.create_table("test_multi_batch", Box::new(reader))
.execute()
.await
.unwrap();
let ctx = SessionContext::new();
let provider =
crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
.await
.unwrap();
ctx.register_table("test_multi_batch", Arc::new(provider))
.unwrap();
// Memtable with multiple batches and multiple partitions
let source_table = MemTable::try_new(
schema.clone(),
vec![
// Partition 0
vec![
record_batch!(("id", Int32, [2, 3])).unwrap(),
record_batch!(("id", Int32, [4, 5])).unwrap(),
],
// Partition 1
vec![record_batch!(("id", Int32, [6, 7, 8])).unwrap()],
],
)
.unwrap();
ctx.register_table("multi_batch_source", Arc::new(source_table))
.unwrap();
ctx.sql("INSERT INTO test_multi_batch SELECT * FROM multi_batch_source")
.await
.unwrap()
.collect()
.await
.unwrap();
// Verify: should have 1 + 2 + 2 + 3 = 8 rows
table.checkout_latest().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 8);
}
}

View File

@@ -100,7 +100,8 @@ impl DatasetRef {
let should_checkout = match &target_ref {
refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
refs::Ref::Version(_, None) => true, // No specific version, always checkout
refs::Ref::Tag(_) => true, // Always checkout for tags
refs::Ref::VersionNumber(target_ver) => version != target_ver,
refs::Ref::Tag(_) => true, // Always checkout for tags
};
if should_checkout {

View File

@@ -0,0 +1,161 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use serde::{Deserialize, Serialize};
use super::NativeTable;
use crate::Result;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// Internal implementation of the delete logic
///
/// This logic was moved from NativeTable::delete to keep table.rs clean.
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
// We access the dataset from the table. Since this is in the same module hierarchy (super),
// and 'dataset' is pub(crate), we can access it.
let mut dataset = table.dataset.get_mut().await?;
// Perform the actual delete on the Lance dataset
dataset.delete(predicate).await?;
// Return the result with the new version
Ok(DeleteResult {
version: dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use crate::connect;
use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
use crate::query::ExecutableQuery;
use futures::TryStreamExt;
#[tokio::test]
async fn test_delete_simple() {
let conn = connect("memory://").execute().await.unwrap();
// 1. Create a table with values 0 to 9
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)
.unwrap();
let table = conn
.create_table(
"test_delete",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// 2. Verify initial state
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// 3. Execute Delete (removes values > 5)
table.delete("i > 5").await.unwrap();
// 4. Verify results
assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain
// 5. Verify specific data consistency
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let array = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
// Ensure no value > 5 exists
for val in array.iter() {
assert!(val.unwrap() <= 5);
}
}
#[tokio::test]
async fn rows_removed_schema_same() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("id", Int32, [1, 2, 3, 4, 5]),
("name", Utf8, ["a", "b", "c", "d", "e"])
)
.unwrap();
let original_schema = batch.schema();
let table = conn
.create_table(
"test_delete_all",
RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()),
)
.execute()
.await
.unwrap();
table.delete("true").await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 0);
let current_schema = table.schema().await.unwrap();
//check if the original schema is the same as current
assert_eq!(current_schema, original_schema);
}
#[tokio::test]
async fn test_delete_false_increments_version() {
let conn = connect("memory://").execute().await.unwrap();
// Create a table with 5 rows
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_delete_noop",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Capture the initial state (Rows = 5, Version = 1)
let initial_rows = table.count_rows(None).await.unwrap();
let initial_version = table.version().await.unwrap();
assert_eq!(initial_rows, 5);
table.delete("false").await.unwrap();
// Rows should still be 5
let current_rows = table.count_rows(None).await.unwrap();
assert_eq!(
current_rows, initial_rows,
"Data should not change when predicate is false"
);
// version check
let current_version = table.version().await.unwrap();
assert!(
current_version > initial_version,
"Table version must increment after delete operation"
);
}
}

View File

@@ -0,0 +1,666 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Schema evolution operations for LanceDB tables.
//!
//! This module provides functionality to modify the schema of existing tables:
//! - [`add_columns`](execute_add_columns): Add new columns using SQL expressions
//! - [`alter_columns`](execute_alter_columns): Rename columns, change types, or modify nullability
//! - [`drop_columns`](execute_drop_columns): Remove columns from the table
use lance::dataset::{ColumnAlteration, NewColumnTransform};
use serde::{Deserialize, Serialize};
use super::NativeTable;
use crate::Result;
/// The result of an add columns operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AddColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// The result of an alter columns operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AlterColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// The result of a drop columns operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DropColumnsResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// Internal implementation of the add columns logic.
///
/// Adds new columns to the table using the provided transforms.
pub(crate) async fn execute_add_columns(
table: &NativeTable,
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
) -> Result<AddColumnsResult> {
let mut dataset = table.dataset.get_mut().await?;
dataset.add_columns(transforms, read_columns, None).await?;
Ok(AddColumnsResult {
version: dataset.version().version,
})
}
/// Internal implementation of the alter columns logic.
///
/// Alters existing columns in the table (rename, change type, or modify nullability).
pub(crate) async fn execute_alter_columns(
table: &NativeTable,
alterations: &[ColumnAlteration],
) -> Result<AlterColumnsResult> {
let mut dataset = table.dataset.get_mut().await?;
dataset.alter_columns(alterations).await?;
Ok(AlterColumnsResult {
version: dataset.version().version,
})
}
/// Internal implementation of the drop columns logic.
///
/// Removes columns from the table.
pub(crate) async fn execute_drop_columns(
table: &NativeTable,
columns: &[&str],
) -> Result<DropColumnsResult> {
let mut dataset = table.dataset.get_mut().await?;
dataset.drop_columns(columns).await?;
Ok(DropColumnsResult {
version: dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use arrow_array::{record_batch, Int32Array, RecordBatchIterator, StringArray};
use arrow_schema::DataType;
use futures::TryStreamExt;
use lance::dataset::ColumnAlteration;
use crate::connect;
use crate::query::{ExecutableQuery, QueryBase, Select};
use crate::table::NewColumnTransform;
// Add Columns Tests
#[tokio::test]
async fn test_add_columns_with_sql_expression() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_add_columns",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let initial_version = table.version().await.unwrap();
// Add a computed column
let result = table
.add_columns(
NewColumnTransform::SqlExpressions(vec![("doubled".into(), "id * 2".into())]),
None,
)
.await
.unwrap();
// Version should increment
assert!(result.version > initial_version);
// Verify the new column exists with correct values
let batches = table
.query()
.select(Select::columns(&["id", "doubled"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let ids: Vec<i32> = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
let doubled: Vec<i32> = batch
.column(1)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
for (id, d) in ids.iter().zip(doubled.iter()) {
assert_eq!(*d, id * 2);
}
}
#[tokio::test]
async fn test_add_multiple_columns() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("x", Int32, [10, 20, 30])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_add_multi_columns",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Add multiple columns at once
table
.add_columns(
NewColumnTransform::SqlExpressions(vec![
("y".into(), "x + 1".into()),
("z".into(), "x * x".into()),
]),
None,
)
.await
.unwrap();
// Verify schema has all columns
let schema = table.schema().await.unwrap();
assert_eq!(schema.fields().len(), 3);
assert!(schema.field_with_name("x").is_ok());
assert!(schema.field_with_name("y").is_ok());
assert!(schema.field_with_name("z").is_ok());
}
#[tokio::test]
async fn test_add_column_with_constant_expression() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_add_const_column",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Add a column with a constant value
table
.add_columns(
NewColumnTransform::SqlExpressions(vec![("constant".into(), "42".into())]),
None,
)
.await
.unwrap();
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("constant").is_ok());
// Verify all values are 42
let batches = table
.query()
.select(Select::columns(&["constant"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let values = batch["constant"]
.as_any()
.downcast_ref::<arrow_array::Int64Array>()
.unwrap()
.values();
assert!(values.iter().all(|&v| v == 42));
}
// Alter Columns Tests
#[tokio::test]
async fn test_alter_column_rename() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("old_name", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_alter_rename",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let initial_version = table.version().await.unwrap();
// Rename the column
let result = table
.alter_columns(&[ColumnAlteration::new("old_name".into()).rename("new_name".into())])
.await
.unwrap();
// Version should increment
assert!(result.version > initial_version);
// Verify rename
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("old_name").is_err());
assert!(schema.field_with_name("new_name").is_ok());
}
#[tokio::test]
async fn test_alter_column_set_nullable() {
use arrow_array::RecordBatch;
use arrow_schema::{Field, Schema};
use std::sync::Arc;
let conn = connect("memory://").execute().await.unwrap();
// Create a schema with a non-nullable field
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let table = conn
.create_table(
"test_alter_nullable",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Initially non-nullable
let schema = table.schema().await.unwrap();
assert!(!schema.field_with_name("value").unwrap().is_nullable());
// Make it nullable
table
.alter_columns(&[ColumnAlteration::new("value".into()).set_nullable(true)])
.await
.unwrap();
// Verify it's now nullable
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("value").unwrap().is_nullable());
}
#[tokio::test]
async fn test_alter_column_cast_type() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("num", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_cast_type",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Cast Int32 to Int64 (a supported cast)
table
.alter_columns(&[ColumnAlteration::new("num".into()).cast_to(DataType::Int64)])
.await
.unwrap();
// Verify type changed
let schema = table.schema().await.unwrap();
assert_eq!(
schema.field_with_name("num").unwrap().data_type(),
&DataType::Int64
);
// Query the data and verify the returned type is correct
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let values = batch["num"]
.as_any()
.downcast_ref::<arrow_array::Int64Array>()
.unwrap()
.values();
assert_eq!(values.as_ref(), &[1i64, 2, 3]);
}
#[tokio::test]
async fn test_alter_column_invalid_cast_fails() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("num", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_invalid_cast",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Casting Int32 to Float64 is not supported
let result = table
.alter_columns(&[ColumnAlteration::new("num".into()).cast_to(DataType::Float64)])
.await;
let err = result.unwrap_err();
assert!(
err.to_string().contains("cast"),
"Expected error message to contain 'cast', got: {}",
err
);
}
#[tokio::test]
async fn test_alter_multiple_columns() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_alter_multi",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Alter multiple columns at once
table
.alter_columns(&[
ColumnAlteration::new("a".into()).rename("alpha".into()),
ColumnAlteration::new("b".into()).set_nullable(true),
])
.await
.unwrap();
let schema = table.schema().await.unwrap();
assert!(schema.field_with_name("alpha").is_ok());
assert!(schema.field_with_name("a").is_err());
assert!(schema.field_with_name("b").unwrap().is_nullable());
}
// Drop Columns Tests
#[tokio::test]
async fn test_drop_single_column() {
let conn = connect("memory://").execute().await.unwrap();
let batch =
record_batch!(("keep", Int32, [1, 2, 3]), ("remove", Int32, [4, 5, 6])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_single",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let initial_version = table.version().await.unwrap();
// Drop a column
let result = table.drop_columns(&["remove"]).await.unwrap();
// Version should increment
assert!(result.version > initial_version);
// Verify column was dropped
let schema = table.schema().await.unwrap();
assert_eq!(schema.fields().len(), 1);
assert!(schema.field_with_name("keep").is_ok());
assert!(schema.field_with_name("remove").is_err());
}
#[tokio::test]
async fn test_drop_multiple_columns() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("a", Int32, [1, 2]),
("b", Int32, [3, 4]),
("c", Int32, [5, 6]),
("d", Int32, [7, 8])
)
.unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_multi",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Drop multiple columns
table.drop_columns(&["b", "d"]).await.unwrap();
// Verify only a and c remain
let schema = table.schema().await.unwrap();
assert_eq!(schema.fields().len(), 2);
assert!(schema.field_with_name("a").is_ok());
assert!(schema.field_with_name("c").is_ok());
assert!(schema.field_with_name("b").is_err());
assert!(schema.field_with_name("d").is_err());
}
#[tokio::test]
async fn test_drop_column_preserves_data() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("id", Int32, [1, 2, 3]),
("name", Utf8, ["a", "b", "c"]),
("extra", Int32, [10, 20, 30])
)
.unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_preserves",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Drop the extra column
table.drop_columns(&["extra"]).await.unwrap();
// Verify remaining data is intact
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 3);
let ids: Vec<i32> = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
assert_eq!(ids, vec![1, 2, 3]);
let names: Vec<&str> = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|v| v.unwrap())
.collect();
assert_eq!(names, vec!["a", "b", "c"]);
}
// Error Case Tests
#[tokio::test]
async fn test_drop_nonexistent_column_fails() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_drop_nonexistent",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Try to drop a column that doesn't exist
let result = table.drop_columns(&["nonexistent"]).await;
let err = result.unwrap_err();
assert!(
err.to_string().contains("nonexistent"),
"Expected error message to contain column name 'nonexistent', got: {}",
err
);
}
#[tokio::test]
async fn test_alter_nonexistent_column_fails() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("existing", Int32, [1, 2, 3])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_alter_nonexistent",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Try to alter a column that doesn't exist
let result = table
.alter_columns(&[ColumnAlteration::new("nonexistent".into()).rename("new".into())])
.await;
let err = result.unwrap_err();
assert!(
err.to_string().contains("nonexistent"),
"Expected error message to contain column name 'nonexistent', got: {}",
err
);
}
// Version Tracking Tests
#[tokio::test]
async fn test_schema_operations_increment_version() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [4, 5, 6])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_version_increment",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
let v1 = table.version().await.unwrap();
// Add column increments version
let add_result = table
.add_columns(
NewColumnTransform::SqlExpressions(vec![("c".into(), "a + b".into())]),
None,
)
.await
.unwrap();
assert!(add_result.version > v1);
let v2 = table.version().await.unwrap();
assert_eq!(add_result.version, v2);
// Alter column increments version
let alter_result = table
.alter_columns(&[ColumnAlteration::new("c".into()).rename("sum".into())])
.await
.unwrap();
assert!(alter_result.version > v2);
let v3 = table.version().await.unwrap();
assert_eq!(alter_result.version, v3);
// Drop column increments version
let drop_result = table.drop_columns(&["b"]).await.unwrap();
assert!(drop_result.version > v3);
let v4 = table.version().await.unwrap();
assert_eq!(drop_result.version, v4);
}
}

View File

@@ -0,0 +1,441 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use lance::dataset::UpdateBuilder as LanceUpdateBuilder;
use serde::{Deserialize, Serialize};
use super::{BaseTable, NativeTable};
use crate::Error;
use crate::Result;
/// The result of an update operation
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct UpdateResult {
#[serde(default)]
pub rows_updated: u64,
/// The commit version associated with the operation.
#[serde(default)]
pub version: u64,
}
/// A builder for configuring a [`crate::table::Table::update`] operation
#[derive(Debug, Clone)]
pub struct UpdateBuilder {
parent: Arc<dyn BaseTable>,
pub(crate) filter: Option<String>,
pub(crate) columns: Vec<(String, String)>,
}
impl UpdateBuilder {
pub(crate) fn new(parent: Arc<dyn BaseTable>) -> Self {
Self {
parent,
filter: None,
columns: Vec::new(),
}
}
/// Limits the update operation to rows matching the given filter
///
/// If a row does not match the filter then it will be left unchanged.
pub fn only_if(mut self, filter: impl Into<String>) -> Self {
self.filter = Some(filter.into());
self
}
/// Specifies a column to update
///
/// This method may be called multiple times to update multiple columns
///
/// The `update_expr` should be an SQL expression explaining how to calculate
/// the new value for the column. The expression will be evaluated against the
/// previous row's value.
pub fn column(
mut self,
column_name: impl Into<String>,
update_expr: impl Into<String>,
) -> Self {
self.columns.push((column_name.into(), update_expr.into()));
self
}
/// Executes the update operation.
pub async fn execute(self) -> Result<UpdateResult> {
if self.columns.is_empty() {
Err(Error::InvalidInput {
message: "at least one column must be specified in an update operation".to_string(),
})
} else {
self.parent.clone().update(self).await
}
}
}
/// Internal implementation of the update logic
pub(crate) async fn execute_update(
table: &NativeTable,
update: UpdateBuilder,
) -> Result<UpdateResult> {
// 1. Snapshot the current dataset
let dataset = table.dataset.get().await?.clone();
// 2. Initialize the Lance Core builder
let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
// 3. Apply the filter (WHERE clause)
if let Some(predicate) = update.filter {
builder = builder.update_where(&predicate)?;
}
// 4. Apply the columns (SET clause)
for (column, value) in update.columns {
builder = builder.set(column, &value)?;
}
// 5. Execute the operation (Write new files)
let operation = builder.build()?;
let res = operation.execute().await?;
// 6. Update the table's view of the latest version
table
.dataset
.set_latest(res.new_dataset.as_ref().clone())
.await;
Ok(UpdateResult {
rows_updated: res.rows_updated,
version: res.new_dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use crate::connect;
use crate::query::QueryBase;
use crate::query::{ExecutableQuery, Select};
use arrow_array::{
record_batch, Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array,
Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
UInt32Array,
};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit};
use futures::TryStreamExt;
use std::sync::Arc;
use std::time::Duration;
#[tokio::test]
async fn test_update_all_types() {
let conn = connect("memory://")
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("int32", DataType::Int32, false),
Field::new("int64", DataType::Int64, false),
Field::new("uint32", DataType::UInt32, false),
Field::new("string", DataType::Utf8, false),
Field::new("large_string", DataType::LargeUtf8, false),
Field::new("float32", DataType::Float32, false),
Field::new("float64", DataType::Float64, false),
Field::new("bool", DataType::Boolean, false),
Field::new("date32", DataType::Date32, false),
Field::new(
"timestamp_ns",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new(
"timestamp_ms",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"vec_f32",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
Field::new(
"vec_f64",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
false,
),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(Int64Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(LargeStringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))),
Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))),
Arc::new(Into::<BooleanArray>::into(vec![
true, false, true, false, true, false, true, false, true, false,
])),
Arc::new(Date32Array::from_iter_values(0..10)),
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
Arc::new(
create_fixed_size_list(
Float32Array::from_iter_values((0..20).map(|i| i as f32)),
2,
)
.unwrap(),
),
Arc::new(
create_fixed_size_list(
Float64Array::from_iter_values((0..20).map(|i| i as f64)),
2,
)
.unwrap(),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
// check it can do update for each type
let updates: Vec<(&str, &str)> = vec![
("string", "'foo'"),
("large_string", "'large_foo'"),
("int32", "1"),
("int64", "1"),
("uint32", "1"),
("float32", "1.0"),
("float64", "1.0"),
("bool", "true"),
("date32", "1"),
("timestamp_ns", "1"),
("timestamp_ms", "1"),
("vec_f32", "[1.0, 1.0]"),
("vec_f64", "[1.0, 1.0]"),
];
let mut update_op = table.update();
for (column, value) in updates {
update_op = update_op.column(column, value);
}
update_op.execute().await.unwrap();
let mut batches = table
.query()
.select(Select::columns(&[
"string",
"large_string",
"int32",
"int64",
"uint32",
"float32",
"float64",
"bool",
"date32",
"timestamp_ns",
"timestamp_ms",
"vec_f32",
"vec_f64",
]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = batches.pop().unwrap();
macro_rules! assert_column {
($column:expr, $array_type:ty, $expected:expr) => {
let array = $column
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
assert_eq!(v, Some($expected));
}
};
}
assert_column!(batch.column(0), StringArray, "foo");
assert_column!(batch.column(1), LargeStringArray, "large_foo");
assert_column!(batch.column(2), Int32Array, 1);
assert_column!(batch.column(3), Int64Array, 1);
assert_column!(batch.column(4), UInt32Array, 1);
assert_column!(batch.column(5), Float32Array, 1.0);
assert_column!(batch.column(6), Float64Array, 1.0);
assert_column!(batch.column(7), BooleanArray, true);
assert_column!(batch.column(8), Date32Array, 1);
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
let array = batch
.column(11)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
for v in f32array {
assert_eq!(v, Some(1.0));
}
}
let array = batch
.column(12)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
for v in f64array {
assert_eq!(v, Some(1.0));
}
}
}
///Two helper functions
fn create_fixed_size_list<T: Array>(
values: T,
list_size: i32,
) -> Result<FixedSizeListArray, ArrowError> {
let list_type = DataType::FixedSizeList(
Arc::new(Field::new("item", values.data_type().clone(), true)),
list_size,
);
let data = ArrayDataBuilder::new(list_type)
.len(values.len() / list_size as usize)
.add_child_data(values.into_data())
.build()
.unwrap();
Ok(FixedSizeListArray::from(data))
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
)
}
#[tokio::test]
async fn test_update_with_predicate() {
let conn = connect("memory://")
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let batch = record_batch!(
("id", Int32, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
(
"name",
Utf8,
["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]
)
)
.unwrap();
let schema = batch.schema();
// need the iterator for create table
let record_batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema);
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
table
.update()
.only_if("id > 5")
.column("name", "'foo'")
.execute()
.await
.unwrap();
let mut batches = table
.query()
.select(Select::columns(&["id", "name"]))
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
while let Some(batch) = batches.pop() {
let ids = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.collect::<Vec<_>>();
let names = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for (i, name) in names.iter().enumerate() {
let id = ids[i].unwrap();
let name = name.unwrap();
if id > 5 {
assert_eq!(name, "foo");
} else {
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
}
}
}
}
#[tokio::test]
async fn test_update_via_expr() {
let conn = connect("memory://")
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let tbl = conn
.create_table("my_table", make_test_batches())
.execute()
.await
.unwrap();
assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
tbl.update().column("i", "i+1").execute().await.unwrap();
assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
}
}