mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-03 04:10:41 +00:00
Compare commits
1 Commits
feat/table
...
jack/idp-o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c42848ae78 |
@@ -1,7 +0,0 @@
|
||||
# Agent Skills
|
||||
|
||||
This directory contains repo-scoped code agent skills for the LanceDB project.
|
||||
|
||||
Each skill is a folder that contains a required `SKILL.md` and optional bundled resources.
|
||||
|
||||
Codex discovers skills from `.agents/skills` in the current working directory and parent directories.
|
||||
@@ -1,98 +0,0 @@
|
||||
---
|
||||
name: lancedb-update-lance-dependency
|
||||
description: Update LanceDB to a specific Lance release or tag. Use when bumping Lance dependencies in the lancedb repository, including Rust workspace Lance crates, Java lance-core, validation, branch creation, commit, push, and PR creation when requested.
|
||||
---
|
||||
|
||||
# LanceDB Update Lance Dependency
|
||||
|
||||
## Scope
|
||||
|
||||
Use this skill in the `lancedb/lancedb` repository when updating the Lance dependency to a specific Lance version or tag.
|
||||
|
||||
Inputs can be a version (`7.2.0-beta.1`), a tag (`v7.2.0-beta.1`), a tag ref (`refs/tags/v7.2.0-beta.1`), or `latest`.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Confirm the worktree status with `git status --short`.
|
||||
2. Resolve the target Lance version:
|
||||
|
||||
- If the input is `latest`, empty, or omitted, run:
|
||||
|
||||
```bash
|
||||
python3 ci/check_lance_release.py
|
||||
```
|
||||
|
||||
Parse the JSON output. If `needs_update` is not `true`, stop without creating a PR. Otherwise use `latest_tag`.
|
||||
|
||||
- If the input is explicit, use it directly.
|
||||
|
||||
3. Compute update metadata without changing files:
|
||||
|
||||
```bash
|
||||
python3 ci/update_lance_dependency.py "$TAG_OR_VERSION" --metadata-only
|
||||
```
|
||||
|
||||
Before making changes, check for an existing open PR with the emitted `pr_title`:
|
||||
|
||||
```bash
|
||||
gh pr list --search "\"$PR_TITLE\" in:title" --state open --limit 1 --json number,url,title
|
||||
```
|
||||
|
||||
If a matching open PR exists, stop and report it instead of creating a duplicate.
|
||||
|
||||
4. Run the deterministic update entrypoint:
|
||||
|
||||
```bash
|
||||
python3 ci/update_lance_dependency.py "$TAG_OR_VERSION"
|
||||
```
|
||||
|
||||
This updates the Rust workspace Lance dependencies through `ci/set_lance_version.py`, updates `java/pom.xml`, refreshes Cargo metadata, and prints JSON metadata containing `branch_name`, `commit_message`, and `pr_title`.
|
||||
|
||||
5. Run validation:
|
||||
|
||||
```bash
|
||||
cargo clippy --quiet --workspace --tests --all-features -- -D warnings
|
||||
cargo fmt --all --quiet
|
||||
```
|
||||
|
||||
Fix real diagnostics and rerun clippy until it succeeds. Do not skip warnings.
|
||||
|
||||
6. Inspect `git status --short` and `git diff` to ensure only the Lance dependency update and required compatibility fixes are present.
|
||||
|
||||
7. If the task only asks to prepare local changes, stop here and report the changed files and validation result.
|
||||
|
||||
8. If the task asks to publish the update, create a branch using the printed `branch_name`, stage all relevant files, and commit using the printed `commit_message`. Do not amend or rewrite existing commits.
|
||||
|
||||
9. Push to `origin`. Before creating the PR, check that the current token has push permission:
|
||||
|
||||
```bash
|
||||
gh api repos/lancedb/lancedb --jq .permissions.push
|
||||
```
|
||||
|
||||
If the remote branch already exists for the same generated branch name, delete the remote ref with `gh api -X DELETE repos/lancedb/lancedb/git/refs/heads/$BRANCH_NAME`, then push. Do not force-push.
|
||||
|
||||
10. Create a PR targeting `main` with the printed `pr_title`. If there is no PR template, keep the body to two or three concise sentences: state the Lance dependency bump, note any required compatibility fixes, and link the triggering Lance tag or release.
|
||||
|
||||
11. Read back the remote PR title after creation. If it is not a Conventional Commit title, fix it immediately.
|
||||
|
||||
12. When running in GitHub Actions after creating the LanceDB PR, trigger the Sophon dependency update:
|
||||
|
||||
```bash
|
||||
gh workflow run codex-bump-lancedb-lance.yml \
|
||||
--repo lancedb/sophon \
|
||||
-f lance_ref="$LANCE_TAG" \
|
||||
-f lancedb_ref="$BRANCH_NAME"
|
||||
gh run list --repo lancedb/sophon --workflow codex-bump-lancedb-lance.yml --limit 1 --json databaseId,url,displayTitle
|
||||
```
|
||||
|
||||
Use the emitted metadata `tag` value as `LANCE_TAG`. Do this only after a new LanceDB PR has been created. If the update was skipped because no update is needed or an open PR already exists, do not trigger Sophon.
|
||||
|
||||
## GitHub Actions
|
||||
|
||||
When this skill is used from GitHub Actions, `TAG`, `GH_TOKEN`, and `GITHUB_TOKEN` may already be set. Resolve `latest` first when `TAG` is empty. Once an explicit tag or version is known, use:
|
||||
|
||||
```bash
|
||||
python3 ci/update_lance_dependency.py "$TAG" --github-output "$GITHUB_OUTPUT"
|
||||
```
|
||||
|
||||
Then use the emitted `branch_name`, `commit_message`, and `pr_title` values for branch, commit, and PR creation.
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.30.1-beta.0"
|
||||
current_version = "0.29.1-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
5
.github/dependabot.yml
vendored
5
.github/dependabot.yml
vendored
@@ -11,11 +11,6 @@ updates:
|
||||
schedule:
|
||||
interval: weekly
|
||||
open-pull-requests-limit: 10
|
||||
# Only update Cargo.lock, never widen/raise the version requirements in
|
||||
# Cargo.toml. The goal is keeping the lockfile (and the binaries we ship)
|
||||
# current on security fixes, not forcing our library's consumers onto
|
||||
# newer minimum versions.
|
||||
versioning-strategy: lockfile-only
|
||||
groups:
|
||||
rust-minor-patch:
|
||||
update-types:
|
||||
|
||||
@@ -29,3 +29,7 @@ runs:
|
||||
args: ${{ inputs.args }}
|
||||
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"
|
||||
working-directory: python
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-wheels
|
||||
path: python\target\wheels
|
||||
|
||||
@@ -4,16 +4,14 @@ on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag name from Lance. If omitted, the skill will use the latest Lance release that needs an update."
|
||||
required: false
|
||||
default: ""
|
||||
description: "Tag name from Lance"
|
||||
required: true
|
||||
type: string
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag name from Lance. Leave empty to use the latest Lance release that needs an update."
|
||||
required: false
|
||||
default: ""
|
||||
description: "Tag name from Lance"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
@@ -27,7 +25,7 @@ jobs:
|
||||
steps:
|
||||
- name: Show inputs
|
||||
run: |
|
||||
echo "tag = ${{ inputs.tag || 'latest' }}"
|
||||
echo "tag = ${{ inputs.tag }}"
|
||||
|
||||
- name: Checkout Repo LanceDB
|
||||
uses: actions/checkout@v4
|
||||
@@ -73,21 +71,65 @@ jobs:
|
||||
OPENAI_API_KEY: ${{ secrets.CODEX_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TARGET_TAG="${TAG:-latest}"
|
||||
VERSION="${TAG#refs/tags/}"
|
||||
VERSION="${VERSION#v}"
|
||||
BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
|
||||
|
||||
# Use "chore" for beta/rc versions, "feat" for stable releases
|
||||
if [[ "${VERSION}" == *beta* ]] || [[ "${VERSION}" == *rc* ]]; then
|
||||
COMMIT_TYPE="chore"
|
||||
else
|
||||
COMMIT_TYPE="feat"
|
||||
fi
|
||||
|
||||
cat <<EOF >/tmp/codex-prompt.txt
|
||||
You are running inside the lancedb repository on a GitHub Actions runner.
|
||||
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
|
||||
|
||||
Use \$lancedb-update-lance-dependency with target "${TARGET_TAG}".
|
||||
Follow these steps exactly:
|
||||
1. Use script "ci/set_lance_version.py" to update Lance Rust dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
|
||||
2. Update the Java lance-core dependency version in "java/pom.xml": change the "<lance-core.version>...</lance-core.version>" property to "${VERSION}".
|
||||
3. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
|
||||
4. After clippy succeeds, run "cargo fmt --all" to format the workspace.
|
||||
5. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
|
||||
6. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
|
||||
7. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
|
||||
8. Push the branch to origin. If the remote branch already exists, delete it first with "gh api -X DELETE repos/lancedb/lancedb/git/refs/heads/${BRANCH_NAME}" then push with "git push origin ${BRANCH_NAME}". Do NOT use "git push --force" or "git push -f".
|
||||
9. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
|
||||
10. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
|
||||
11. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
|
||||
|
||||
Constraints:
|
||||
- Use env "GH_TOKEN" for GitHub operations.
|
||||
- Do not merge the pull request.
|
||||
- Do not force-push.
|
||||
- Do not create a duplicate pull request if an open PR already exists for the target Lance version.
|
||||
- If any command fails, diagnose and fix the root cause instead of aborting.
|
||||
- After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
|
||||
- Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.
|
||||
- Do not merge the PR.
|
||||
- If any command fails, diagnose and fix the issue instead of aborting.
|
||||
EOF
|
||||
|
||||
printenv OPENAI_API_KEY | codex login --with-api-key
|
||||
codex --config shell_environment_policy.ignore_default_excludes=true exec --dangerously-bypass-approvals-and-sandbox "$(cat /tmp/codex-prompt.txt)"
|
||||
|
||||
- name: Trigger sophon dependency update
|
||||
env:
|
||||
TAG: ${{ inputs.tag }}
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
VERSION="${TAG#refs/tags/}"
|
||||
VERSION="${VERSION#v}"
|
||||
LANCEDB_BRANCH="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
|
||||
|
||||
echo "Triggering sophon workflow with:"
|
||||
echo " lance_ref: ${TAG#refs/tags/}"
|
||||
echo " lancedb_ref: ${LANCEDB_BRANCH}"
|
||||
|
||||
gh workflow run codex-bump-lancedb-lance.yml \
|
||||
--repo lancedb/sophon \
|
||||
-f lance_ref="${TAG#refs/tags/}" \
|
||||
-f lancedb_ref="${LANCEDB_BRANCH}"
|
||||
|
||||
- name: Show latest sophon workflow run
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Latest sophon workflow run:"
|
||||
gh run list --repo lancedb/sophon --workflow codex-bump-lancedb-lance.yml --limit 1 --json databaseId,url,displayTitle
|
||||
|
||||
62
.github/workflows/lance-release-timer.yml
vendored
Normal file
62
.github/workflows/lance-release-timer.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
name: Lance Release Timer
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "*/10 * * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
actions: write
|
||||
|
||||
concurrency:
|
||||
group: lance-release-timer
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
trigger-update:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check for new Lance tag
|
||||
id: check
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
python3 ci/check_lance_release.py --github-output "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Look for existing PR
|
||||
if: steps.check.outputs.needs_update == 'true'
|
||||
id: pr
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TITLE="chore: update lance dependency to v${{ steps.check.outputs.latest_version }}"
|
||||
COUNT=$(gh pr list --search "\"$TITLE\" in:title" --state open --limit 1 --json number --jq 'length')
|
||||
if [ "$COUNT" -gt 0 ]; then
|
||||
echo "Open PR already exists for $TITLE"
|
||||
echo "pr_exists=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "No existing PR for $TITLE"
|
||||
echo "pr_exists=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Trigger codex update workflow
|
||||
if: steps.check.outputs.needs_update == 'true' && steps.pr.outputs.pr_exists != 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG=${{ steps.check.outputs.latest_tag }}
|
||||
gh workflow run codex-update-lance-dependency.yml -f tag=refs/tags/$TAG
|
||||
|
||||
- name: Show latest codex workflow run
|
||||
if: steps.check.outputs.needs_update == 'true' && steps.pr.outputs.pr_exists != 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
gh run list --workflow codex-update-lance-dependency.yml --limit 1 --json databaseId,url,displayTitle
|
||||
5
.github/workflows/nodejs.yml
vendored
5
.github/workflows/nodejs.yml
vendored
@@ -157,10 +157,7 @@ jobs:
|
||||
npx jest --testEnvironment jest-environment-node-single-context --verbose
|
||||
macos:
|
||||
timeout-minutes: 30
|
||||
# macos-15 ships a newer linker; the older macos-14 linker fails to insert
|
||||
# branch islands when the debug cdylib's __text section exceeds the 128 MB
|
||||
# AArch64 B/BL branch range.
|
||||
runs-on: "macos-15"
|
||||
runs-on: "macos-14"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
110
.github/workflows/pypi-publish.yml
vendored
110
.github/workflows/pypi-publish.yml
vendored
@@ -8,9 +8,6 @@ on:
|
||||
# This should trigger a dry run (we skip the final publish step)
|
||||
paths:
|
||||
- .github/workflows/pypi-publish.yml
|
||||
- .github/workflows/build_linux_wheel/action.yml
|
||||
- .github/workflows/build_mac_wheel/action.yml
|
||||
- .github/workflows/build_windows_wheel/action.yml
|
||||
- Cargo.toml # Change in dependency frequently breaks builds
|
||||
- Cargo.lock
|
||||
|
||||
@@ -24,21 +21,32 @@ jobs:
|
||||
linux:
|
||||
name: Python ${{ matrix.config.platform }} manylinux${{ matrix.config.manylinux }}
|
||||
timeout-minutes: 60
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
strategy:
|
||||
matrix:
|
||||
config:
|
||||
- platform: x86_64
|
||||
manylinux: "2_17"
|
||||
extra_args: ""
|
||||
runner: ubuntu-22.04
|
||||
- platform: x86_64
|
||||
manylinux: "2_28"
|
||||
extra_args: "--features fp16kernels"
|
||||
runner: ubuntu-22.04
|
||||
# For successful fat LTO builds, we need a large runner to avoid OOM errors.
|
||||
- platform: aarch64
|
||||
manylinux: "2_17"
|
||||
extra_args: ""
|
||||
# For successful fat LTO builds, we need a large runner to avoid OOM errors.
|
||||
runner: ubuntu-2404-8x-arm64
|
||||
- platform: aarch64
|
||||
manylinux: "2_28"
|
||||
extra_args: "--features fp16kernels"
|
||||
runner: ubuntu-2404-8x-arm64
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
@@ -52,14 +60,15 @@ jobs:
|
||||
args: "--release --strip ${{ matrix.config.extra_args }}"
|
||||
arm-build: ${{ matrix.config.platform == 'aarch64' }}
|
||||
manylinux: ${{ matrix.config.manylinux }}
|
||||
- uses: actions/upload-artifact@v7
|
||||
- uses: ./.github/workflows/upload_wheel
|
||||
if: startsWith(github.ref, 'refs/tags/python-v')
|
||||
with:
|
||||
name: wheels-linux-${{ matrix.config.platform }}-${{ matrix.config.manylinux }}
|
||||
path: target/wheels/lancedb-*.whl
|
||||
if-no-files-found: error
|
||||
fury_token: ${{ secrets.FURY_TOKEN }}
|
||||
mac:
|
||||
timeout-minutes: 90
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -69,7 +78,7 @@ jobs:
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: 10.15
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
@@ -81,21 +90,18 @@ jobs:
|
||||
with:
|
||||
python-minor-version: 10
|
||||
args: "--release --strip --target ${{ matrix.config.target }} --features fp16kernels"
|
||||
- uses: actions/upload-artifact@v7
|
||||
- uses: ./.github/workflows/upload_wheel
|
||||
if: startsWith(github.ref, 'refs/tags/python-v')
|
||||
with:
|
||||
name: wheels-mac-${{ matrix.config.target }}
|
||||
path: target/wheels/lancedb-*.whl
|
||||
if-no-files-found: error
|
||||
fury_token: ${{ secrets.FURY_TOKEN }}
|
||||
windows:
|
||||
timeout-minutes: 90
|
||||
timeout-minutes: 60
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
runs-on: windows-latest
|
||||
env:
|
||||
# link.exe is single-threaded and the long pole on Windows builds. Use
|
||||
# rustc's bundled lld-link instead.
|
||||
CARGO_TARGET_X86_64_PC_WINDOWS_MSVC_LINKER: rust-lld
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
@@ -107,70 +113,18 @@ jobs:
|
||||
with:
|
||||
python-minor-version: 10
|
||||
args: "--release --strip"
|
||||
- uses: actions/upload-artifact@v7
|
||||
vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }}
|
||||
- uses: ./.github/workflows/upload_wheel
|
||||
if: startsWith(github.ref, 'refs/tags/python-v')
|
||||
with:
|
||||
name: wheels-windows
|
||||
path: target/wheels/lancedb-*.whl
|
||||
if-no-files-found: error
|
||||
publish:
|
||||
name: Publish wheels
|
||||
if: startsWith(github.ref, 'refs/tags/python-v')
|
||||
needs: [linux, mac, windows]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Download wheel artifacts
|
||||
uses: actions/download-artifact@v8
|
||||
with:
|
||||
pattern: wheels-*
|
||||
path: target/wheels
|
||||
merge-multiple: true
|
||||
- name: List wheels
|
||||
run: ls -la target/wheels
|
||||
- name: Choose repo
|
||||
id: choose_repo
|
||||
run: |
|
||||
if [[ ${{ github.ref }} == *beta* ]]; then
|
||||
echo "repo=fury" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "repo=pypi" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
- name: Publish to Fury
|
||||
if: steps.choose_repo.outputs.repo == 'fury'
|
||||
env:
|
||||
FURY_TOKEN: ${{ secrets.FURY_TOKEN }}
|
||||
run: |
|
||||
shopt -s nullglob
|
||||
WHEELS=(target/wheels/lancedb-*.whl)
|
||||
if [[ ${#WHEELS[@]} -eq 0 ]]; then
|
||||
echo "No wheels found in target/wheels/" >&2
|
||||
exit 1
|
||||
fi
|
||||
for WHEEL in "${WHEELS[@]}"; do
|
||||
echo "Uploading $WHEEL to Fury"
|
||||
curl -f -F package=@"$WHEEL" "https://$FURY_TOKEN@push.fury.io/lancedb/"
|
||||
done
|
||||
# NOTE: pypa/gh-action-pypi-publish must be invoked directly from a
|
||||
# workflow file, not from inside a composite action. When called from a
|
||||
# composite, `github.action_repository` is empty (actions/runner#2473)
|
||||
# and the action falls back to `github.repository`, producing a bogus
|
||||
# `docker://ghcr.io/<repo>:<ref>` image reference that GHA tries to pull.
|
||||
- name: Publish to PyPI
|
||||
if: steps.choose_repo.outputs.repo == 'pypi'
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
packages-dir: target/wheels/
|
||||
fury_token: ${{ secrets.FURY_TOKEN }}
|
||||
gh-release:
|
||||
if: startsWith(github.ref, 'refs/tags/python-v')
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
@@ -233,13 +187,13 @@ jobs:
|
||||
report-failure:
|
||||
name: Report Workflow Failure
|
||||
runs-on: ubuntu-latest
|
||||
needs: [linux, mac, windows, publish]
|
||||
needs: [linux, mac, windows]
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
if: always() && failure() && startsWith(github.ref, 'refs/tags/python-v')
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- uses: ./.github/actions/create-failure-issue
|
||||
with:
|
||||
job-results: ${{ toJSON(needs) }}
|
||||
|
||||
2
.github/workflows/python.yml
vendored
2
.github/workflows/python.yml
vendored
@@ -205,7 +205,7 @@ jobs:
|
||||
- name: Delete wheels
|
||||
run: rm -rf target/wheels
|
||||
pydantic1x:
|
||||
timeout-minutes: 60
|
||||
timeout-minutes: 30
|
||||
runs-on: "ubuntu-24.04"
|
||||
defaults:
|
||||
run:
|
||||
|
||||
20
.github/workflows/rust.yml
vendored
20
.github/workflows/rust.yml
vendored
@@ -233,26 +233,6 @@ jobs:
|
||||
cargo update -p aws-sdk-sso --precise 1.62.0
|
||||
cargo update -p aws-sdk-ssooidc --precise 1.63.0
|
||||
cargo update -p aws-sdk-sts --precise 1.63.0
|
||||
# aws-runtime/sigv4/credential-types/types and the aws-smithy-*
|
||||
# crates bumped their MSRV to 1.91.1 in late 2026; pin to the last
|
||||
# 1.91.0-compatible versions. The order matters — each downgrade
|
||||
# only succeeds once everything that still pins it at a higher
|
||||
# version has itself been downgraded.
|
||||
cargo update -p aws-runtime --precise 1.5.12
|
||||
cargo update -p aws-types --precise 1.3.9
|
||||
cargo update -p aws-sigv4 --precise 1.3.5
|
||||
cargo update -p aws-credential-types --precise 1.2.8
|
||||
cargo update -p aws-smithy-checksums --precise 0.63.9
|
||||
cargo update -p aws-smithy-runtime --precise 1.9.3
|
||||
cargo update -p aws-smithy-http --precise 0.62.4
|
||||
cargo update -p aws-smithy-eventstream --precise 0.60.12
|
||||
cargo update -p aws-smithy-http-client --precise 1.1.3
|
||||
cargo update -p aws-smithy-observability --precise 0.1.4
|
||||
cargo update -p aws-smithy-query --precise 0.60.8
|
||||
cargo update -p aws-smithy-runtime-api --precise 1.9.1
|
||||
cargo update -p aws-smithy-async --precise 1.2.6
|
||||
cargo update -p aws-smithy-types --precise 1.3.5
|
||||
cargo update -p aws-smithy-xml --precise 0.60.11
|
||||
cargo update -p home --precise 0.5.9
|
||||
- name: cargo +${{ matrix.msrv }} check
|
||||
env:
|
||||
|
||||
34
.github/workflows/upload_wheel/action.yml
vendored
Normal file
34
.github/workflows/upload_wheel/action.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: upload-wheel
|
||||
|
||||
description: "Upload wheels to Pypi"
|
||||
inputs:
|
||||
fury_token:
|
||||
required: true
|
||||
description: "release token for the fury repo"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Choose repo
|
||||
shell: bash
|
||||
id: choose_repo
|
||||
run: |
|
||||
if [[ ${{ github.ref }} == *beta* ]]; then
|
||||
echo "repo=fury" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "repo=pypi" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
- name: Publish to Fury
|
||||
if: steps.choose_repo.outputs.repo == 'fury'
|
||||
shell: bash
|
||||
env:
|
||||
FURY_TOKEN: ${{ inputs.fury_token }}
|
||||
run: |
|
||||
WHEEL=$(ls target/wheels/lancedb-*.whl 2> /dev/null | head -n 1)
|
||||
echo "Uploading $WHEEL to Fury"
|
||||
curl -f -F package=@$WHEEL https://$FURY_TOKEN@push.fury.io/lancedb/
|
||||
- name: Publish to PyPI
|
||||
if: steps.choose_repo.outputs.repo == 'pypi'
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
packages-dir: target/wheels/
|
||||
28
AGENTS.md
28
AGENTS.md
@@ -17,33 +17,9 @@ Common commands:
|
||||
* Run tests: `cargo test --quiet --features remote --tests`
|
||||
* Run specific test: `cargo test --quiet --features remote -p <package_name> --test <test_name>`
|
||||
* Lint: `cargo clippy --quiet --features remote --tests --examples`
|
||||
* Format Rust: `cargo fmt --all`
|
||||
* Format Python: `ruff format .`
|
||||
* Lint Python: `ruff check .`
|
||||
* Bootstrap Python dev env: `cd python && uv run --extra tests --extra dev maturin develop --extras tests,dev`
|
||||
* Run Python tests: `cd python && uv run --extra tests pytest python/tests -vv --durations=10 -m "not slow and not s3_test"`
|
||||
* Run specific Python test: `cd python && uv run --extra tests pytest python/tests/<test_file>.py::<test_name> -q`
|
||||
* Format: `cargo fmt --all`
|
||||
|
||||
For Python validation, prefer the uv-managed environment declared by `python/uv.lock`.
|
||||
Do not treat system `python`, global `pytest`, or missing editable-install errors as
|
||||
final blockers; bootstrap or enter the uv environment instead. If `lancedb._lancedb`
|
||||
is missing or stale, or if Rust/PyO3 binding code changed, rebuild the Python
|
||||
extension with the bootstrap command above before running tests.
|
||||
|
||||
Before committing changes, run formatting for every language you touched. At minimum:
|
||||
|
||||
* Rust changes: run `cargo fmt --all`.
|
||||
* Python changes: run `ruff format .` and `ruff check .` from the repository root,
|
||||
and run targeted tests through `cd python && uv run ...`.
|
||||
* TypeScript changes: run the relevant `npm`/`pnpm` lint, format, build, and docs commands in `nodejs`.
|
||||
|
||||
Before creating a PR, the exact value passed to `gh pr create --title` must follow
|
||||
Conventional Commits, such as `fix: support nested field paths in native index creation`
|
||||
or `feat(python): add dataset multiprocessing support`. Do not use a plain natural
|
||||
language summary like `Support nested field paths in native index creation` as the PR
|
||||
title. The semantic-release check uses the PR title and body as the merge commit message,
|
||||
so a non-conventional PR title will fail CI. After creating a PR, read the remote PR title
|
||||
back and fix it immediately if it is not conventional.
|
||||
Before committing changes, run formatting.
|
||||
|
||||
## Coding tips
|
||||
|
||||
|
||||
1565
Cargo.lock
generated
1565
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
28
Cargo.toml
28
Cargo.toml
@@ -13,20 +13,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=7.2.0-beta.3", default-features = false, "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=7.2.0-beta.3", "tag" = "v7.2.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "58.0.0", optional = false }
|
||||
|
||||
@@ -112,25 +112,25 @@ def fetch_remote_tags() -> List[TagInfo]:
|
||||
"api",
|
||||
"-X",
|
||||
"GET",
|
||||
f"repos/{LANCE_REPO}/releases",
|
||||
f"repos/{LANCE_REPO}/git/refs/tags",
|
||||
"--paginate",
|
||||
"--jq",
|
||||
".[].tag_name",
|
||||
"-F",
|
||||
"per_page=20",
|
||||
".[].ref",
|
||||
]
|
||||
)
|
||||
tags: List[TagInfo] = []
|
||||
for line in output.splitlines():
|
||||
tag = line.strip()
|
||||
if not tag.startswith("v"):
|
||||
ref = line.strip()
|
||||
if not ref.startswith("refs/tags/v"):
|
||||
continue
|
||||
tag = ref.split("refs/tags/")[-1]
|
||||
version = tag.lstrip("v")
|
||||
try:
|
||||
tags.append(TagInfo(tag=tag, version=version, semver=parse_semver(version)))
|
||||
except ValueError:
|
||||
continue
|
||||
if not tags:
|
||||
raise RuntimeError("No Lance releases could be parsed from GitHub API output")
|
||||
raise RuntimeError("No Lance tags could be parsed from GitHub API output")
|
||||
return tags
|
||||
|
||||
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prepare a Lance dependency update for LanceDB."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
try:
|
||||
from check_lance_release import parse_semver
|
||||
except ModuleNotFoundError:
|
||||
# Supports importing as ci.update_lance_dependency from tests or ad hoc checks.
|
||||
from ci.check_lance_release import parse_semver # type: ignore
|
||||
|
||||
|
||||
def normalize_version(raw: str) -> str:
|
||||
value = raw.strip()
|
||||
value = value.removeprefix("refs/tags/")
|
||||
value = value.removeprefix("v")
|
||||
try:
|
||||
parse_semver(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported Lance version or tag: {raw}")
|
||||
return value
|
||||
|
||||
|
||||
def normalized_tag(version: str) -> str:
|
||||
return f"v{version}"
|
||||
|
||||
|
||||
def branch_name(version: str) -> str:
|
||||
suffix = re.sub(r"[^a-zA-Z0-9]+", "-", version).strip("-")
|
||||
suffix = re.sub(r"-+", "-", suffix)
|
||||
return f"codex/update-lance-{suffix}"
|
||||
|
||||
|
||||
def commit_type(version: str) -> str:
|
||||
prerelease = version.split("-", maxsplit=1)[1] if "-" in version else ""
|
||||
return "chore" if "beta" in prerelease or "rc" in prerelease else "feat"
|
||||
|
||||
|
||||
def metadata_for(version: str) -> dict[str, str]:
|
||||
kind = commit_type(version)
|
||||
message = f"{kind}: update lance dependency to v{version}"
|
||||
return {
|
||||
"version": version,
|
||||
"tag": normalized_tag(version),
|
||||
"branch_name": branch_name(version),
|
||||
"commit_type": kind,
|
||||
"commit_message": message,
|
||||
"pr_title": message,
|
||||
}
|
||||
|
||||
|
||||
def run_command(cmd: Sequence[str], *, cwd: Path) -> None:
|
||||
subprocess.run(cmd, cwd=cwd, check=True)
|
||||
|
||||
|
||||
def update_java_lance_core_version(repo_root: Path, version: str) -> None:
|
||||
pom_path = repo_root / "java" / "pom.xml"
|
||||
contents = pom_path.read_text(encoding="utf-8")
|
||||
updated, count = re.subn(
|
||||
r"(<lance-core\.version>)[^<]+(</lance-core\.version>)",
|
||||
rf"\g<1>{version}\g<2>",
|
||||
contents,
|
||||
count=1,
|
||||
)
|
||||
if count != 1:
|
||||
raise RuntimeError(
|
||||
"Expected exactly one <lance-core.version> entry in java/pom.xml"
|
||||
)
|
||||
pom_path.write_text(updated, encoding="utf-8")
|
||||
|
||||
|
||||
def write_github_outputs(path: str | None, payload: dict[str, str]) -> None:
|
||||
if not path:
|
||||
return
|
||||
with open(path, "a", encoding="utf-8") as output:
|
||||
for key, value in payload.items():
|
||||
output.write(f"{key}={value}\n")
|
||||
|
||||
|
||||
def main(argv: Sequence[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"tag_or_version",
|
||||
help="Lance tag or version, for example refs/tags/v7.2.0-beta.1 or 7.2.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-root",
|
||||
type=Path,
|
||||
default=Path(__file__).resolve().parents[1],
|
||||
help="Path to the lancedb repository root",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--github-output",
|
||||
default=None,
|
||||
help="Optional GitHub Actions output file to receive metadata fields",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata-only",
|
||||
action="store_true",
|
||||
help="Only print derived metadata; do not modify dependency files",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
repo_root = args.repo_root.resolve()
|
||||
version = normalize_version(args.tag_or_version)
|
||||
payload = metadata_for(version)
|
||||
|
||||
if not args.metadata_only:
|
||||
run_command([sys.executable, "ci/set_lance_version.py", version], cwd=repo_root)
|
||||
update_java_lance_core_version(repo_root, version)
|
||||
|
||||
write_github_outputs(args.github_output, payload)
|
||||
print(json.dumps(payload, sort_keys=True))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<version>0.29.1-beta.0</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / BranchContents
|
||||
|
||||
# Class: BranchContents
|
||||
|
||||
## Constructors
|
||||
|
||||
### new BranchContents()
|
||||
|
||||
```ts
|
||||
new BranchContents(): BranchContents
|
||||
```
|
||||
|
||||
#### Returns
|
||||
|
||||
[`BranchContents`](BranchContents.md)
|
||||
|
||||
## Properties
|
||||
|
||||
### manifestSize
|
||||
|
||||
```ts
|
||||
manifestSize: number;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### parentBranch?
|
||||
|
||||
```ts
|
||||
optional parentBranch: string;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### parentVersion
|
||||
|
||||
```ts
|
||||
parentVersion: number;
|
||||
```
|
||||
@@ -1,90 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / Branches
|
||||
|
||||
# Class: Branches
|
||||
|
||||
Branch manager for a [Table](Table.md).
|
||||
|
||||
Unlike tags, `create` and `checkout` return a new [Table](Table.md) handle scoped
|
||||
to the branch; writes on it do not affect `main`.
|
||||
|
||||
## Methods
|
||||
|
||||
### checkout()
|
||||
|
||||
```ts
|
||||
checkout(name): Promise<Table>
|
||||
```
|
||||
|
||||
Check out an existing branch and return a handle scoped to it.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **name**: `string`
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<[`Table`](Table.md)>
|
||||
|
||||
***
|
||||
|
||||
### create()
|
||||
|
||||
```ts
|
||||
create(
|
||||
name,
|
||||
fromRef?,
|
||||
fromVersion?): Promise<Table>
|
||||
```
|
||||
|
||||
Create a branch and return a handle scoped to it.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **name**: `string`
|
||||
Name of the new branch.
|
||||
|
||||
* **fromRef?**: `string`
|
||||
Source branch to fork from. Defaults to `main`.
|
||||
|
||||
* **fromVersion?**: `number`
|
||||
A specific version on `fromRef`. Defaults to latest.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<[`Table`](Table.md)>
|
||||
|
||||
***
|
||||
|
||||
### delete()
|
||||
|
||||
```ts
|
||||
delete(name): Promise<void>
|
||||
```
|
||||
|
||||
Delete a branch.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **name**: `string`
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`void`>
|
||||
|
||||
***
|
||||
|
||||
### list()
|
||||
|
||||
```ts
|
||||
list(): Promise<Record<string, BranchContents>>
|
||||
```
|
||||
|
||||
List all branches, mapping name to branch metadata.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`Record`<`string`, [`BranchContents`](BranchContents.md)>>
|
||||
@@ -441,28 +441,18 @@ Open a table in the database.
|
||||
|
||||
```ts
|
||||
abstract renameTable(
|
||||
currentName,
|
||||
oldName,
|
||||
newName,
|
||||
options?): Promise<void>
|
||||
namespacePath?): Promise<void>
|
||||
```
|
||||
|
||||
Rename a table.
|
||||
|
||||
Currently only supported by LanceDB Cloud. Local OSS connections and
|
||||
namespace-backed connections (via [connectNamespace](../functions/connectNamespace.md)) reject with
|
||||
a "not supported" error.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **currentName**: `string`
|
||||
The current name of the table.
|
||||
* **oldName**: `string`
|
||||
|
||||
* **newName**: `string`
|
||||
The new name for the table.
|
||||
|
||||
* **options?**: [`RenameTableOptions`](../interfaces/RenameTableOptions.md)
|
||||
Optional namespace paths. When
|
||||
`newNamespacePath` is omitted the table stays in `namespacePath`.
|
||||
* **namespacePath?**: `string`[]
|
||||
|
||||
#### Returns
|
||||
|
||||
|
||||
@@ -76,57 +76,6 @@ the query optimizer chooses a suboptimal path.
|
||||
|
||||
***
|
||||
|
||||
### useLsmWrite()
|
||||
|
||||
```ts
|
||||
useLsmWrite(useLsmWrite): MergeInsertBuilder
|
||||
```
|
||||
|
||||
Controls whether the merge uses the MemWAL LSM write path.
|
||||
|
||||
By default (unset), a `mergeInsert` on a table with an LSM write spec is
|
||||
routed through Lance's MemWAL shard writer, and a table without one uses
|
||||
the standard path. Pass `false` to force the standard path even when a
|
||||
spec is set. Pass `true` to require a spec — `mergeInsert` rejects if none
|
||||
is installed.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **useLsmWrite**: `boolean`
|
||||
Whether to use the LSM write path.
|
||||
|
||||
#### Returns
|
||||
|
||||
[`MergeInsertBuilder`](MergeInsertBuilder.md)
|
||||
|
||||
***
|
||||
|
||||
### validateSingleShard()
|
||||
|
||||
```ts
|
||||
validateSingleShard(validateSingleShard): MergeInsertBuilder
|
||||
```
|
||||
|
||||
Controls how an LSM merge checks that its input targets a single shard.
|
||||
|
||||
When a table has an LSM write spec, every row in a `mergeInsert` call must
|
||||
route to the same shard. When `true` (the default), every row is inspected
|
||||
to verify this. When `false`, only the first row is inspected and the
|
||||
shard it routes to is used for the whole input — a faster path for callers
|
||||
that have already pre-sharded their input. Has no effect on tables without
|
||||
an LSM write spec.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **validateSingleShard**: `boolean`
|
||||
Whether to check every row routes to one shard. Defaults to `true`.
|
||||
|
||||
#### Returns
|
||||
|
||||
[`MergeInsertBuilder`](MergeInsertBuilder.md)
|
||||
|
||||
***
|
||||
|
||||
### whenMatchedUpdateAll()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -110,23 +110,6 @@ containing the new version number of the table after altering the columns.
|
||||
|
||||
***
|
||||
|
||||
### branches()
|
||||
|
||||
```ts
|
||||
abstract branches(): Promise<Branches>
|
||||
```
|
||||
|
||||
Get the branch manager for this table.
|
||||
|
||||
Branches are isolated, writable lines of history forked from another
|
||||
branch (or version). Writes on a branch do not affect `main`.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<[`Branches`](Branches.md)>
|
||||
|
||||
***
|
||||
|
||||
### checkout()
|
||||
|
||||
```ts
|
||||
@@ -204,25 +187,6 @@ Any attempt to use the table after it is closed will result in an error.
|
||||
|
||||
***
|
||||
|
||||
### closeLsmWriters()
|
||||
|
||||
```ts
|
||||
abstract closeLsmWriters(): Promise<void>
|
||||
```
|
||||
|
||||
Drain and close any cached MemWAL shard writers held for this table.
|
||||
|
||||
When an [LsmWriteSpec](../interfaces/LsmWriteSpec.md) is installed, `mergeInsert` opens MemWAL
|
||||
shard writers and caches them for reuse across calls. This closes them,
|
||||
flushing pending data; writers reopen lazily on the next `mergeInsert`.
|
||||
It is a no-op when no writers are cached.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`void`>
|
||||
|
||||
***
|
||||
|
||||
### countRows()
|
||||
|
||||
```ts
|
||||
@@ -1011,29 +975,6 @@ based on the row being updated (e.g. "my_col + 1")
|
||||
|
||||
***
|
||||
|
||||
### updateFieldMetadata()
|
||||
|
||||
```ts
|
||||
abstract updateFieldMetadata(updates): Promise<UpdateFieldMetadataResult>
|
||||
```
|
||||
|
||||
Update per-field (column) metadata.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **updates**: [`FieldMetadataUpdate`](../interfaces/FieldMetadataUpdate.md)[]
|
||||
One or more per-field updates. Each
|
||||
update's metadata is merged into the field's existing metadata by default;
|
||||
a value of `null` deletes that key, and `replace: true` swaps the whole map.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<[`UpdateFieldMetadataResult`](../interfaces/UpdateFieldMetadataResult.md)>
|
||||
|
||||
resolves to the new table version.
|
||||
|
||||
***
|
||||
|
||||
### vectorSearch()
|
||||
|
||||
```ts
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
## Enumerations
|
||||
|
||||
- [FullTextQueryType](enumerations/FullTextQueryType.md)
|
||||
- [OAuthFlowType](enumerations/OAuthFlowType.md)
|
||||
- [Occur](enumerations/Occur.md)
|
||||
- [Operator](enumerations/Operator.md)
|
||||
|
||||
@@ -19,8 +20,6 @@
|
||||
|
||||
- [BooleanQuery](classes/BooleanQuery.md)
|
||||
- [BoostQuery](classes/BoostQuery.md)
|
||||
- [BranchContents](classes/BranchContents.md)
|
||||
- [Branches](classes/Branches.md)
|
||||
- [Connection](classes/Connection.md)
|
||||
- [HeaderProvider](classes/HeaderProvider.md)
|
||||
- [Index](classes/Index.md)
|
||||
@@ -67,7 +66,6 @@
|
||||
- [DropNamespaceOptions](interfaces/DropNamespaceOptions.md)
|
||||
- [DropNamespaceResponse](interfaces/DropNamespaceResponse.md)
|
||||
- [ExecutableQuery](interfaces/ExecutableQuery.md)
|
||||
- [FieldMetadataUpdate](interfaces/FieldMetadataUpdate.md)
|
||||
- [FragmentStatistics](interfaces/FragmentStatistics.md)
|
||||
- [FragmentSummaryStats](interfaces/FragmentSummaryStats.md)
|
||||
- [FtsOptions](interfaces/FtsOptions.md)
|
||||
@@ -85,12 +83,13 @@
|
||||
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
|
||||
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
|
||||
- [MergeResult](interfaces/MergeResult.md)
|
||||
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
|
||||
- [OAuthConfig](interfaces/OAuthConfig.md)
|
||||
- [OpenTableOptions](interfaces/OpenTableOptions.md)
|
||||
- [OptimizeOptions](interfaces/OptimizeOptions.md)
|
||||
- [OptimizeStats](interfaces/OptimizeStats.md)
|
||||
- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
|
||||
- [RemovalStats](interfaces/RemovalStats.md)
|
||||
- [RenameTableOptions](interfaces/RenameTableOptions.md)
|
||||
- [RestNamespaceConfig](interfaces/RestNamespaceConfig.md)
|
||||
- [RetryConfig](interfaces/RetryConfig.md)
|
||||
- [ScannableOptions](interfaces/ScannableOptions.md)
|
||||
@@ -104,12 +103,10 @@
|
||||
- [TimeoutConfig](interfaces/TimeoutConfig.md)
|
||||
- [TlsConfig](interfaces/TlsConfig.md)
|
||||
- [TokenResponse](interfaces/TokenResponse.md)
|
||||
- [UpdateFieldMetadataResult](interfaces/UpdateFieldMetadataResult.md)
|
||||
- [UpdateOptions](interfaces/UpdateOptions.md)
|
||||
- [UpdateResult](interfaces/UpdateResult.md)
|
||||
- [Version](interfaces/Version.md)
|
||||
- [WriteExecutionOptions](interfaces/WriteExecutionOptions.md)
|
||||
- [WriteProgress](interfaces/WriteProgress.md)
|
||||
|
||||
## Type Aliases
|
||||
|
||||
|
||||
@@ -19,39 +19,3 @@ mode: "append" | "overwrite";
|
||||
If "append" (the default) then the new data will be added to the table
|
||||
|
||||
If "overwrite" then the new data will replace the existing data in the table.
|
||||
|
||||
***
|
||||
|
||||
### progress()
|
||||
|
||||
```ts
|
||||
progress: (progress) => void;
|
||||
```
|
||||
|
||||
Optional callback invoked periodically with write progress.
|
||||
|
||||
The callback is fired once per batch written and once more with
|
||||
`done: true` when the write completes. Calls are dispatched
|
||||
asynchronously to the JS event loop and never block the write — a slow
|
||||
callback will queue events rather than back-pressure the writer.
|
||||
|
||||
Errors thrown from the callback are logged with `console.warn` and
|
||||
swallowed — they do not abort the write.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **progress**: [`WriteProgress`](WriteProgress.md)
|
||||
|
||||
#### Returns
|
||||
|
||||
`void`
|
||||
|
||||
#### Example
|
||||
|
||||
```ts
|
||||
await table.add(data, {
|
||||
progress: (p) => {
|
||||
console.log(`${p.outputRows}/${p.totalRows ?? "?"} rows`);
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
@@ -64,26 +64,34 @@ client used by manifest-enabled native connections.
|
||||
|
||||
***
|
||||
|
||||
### oauthConfig?
|
||||
|
||||
```ts
|
||||
optional oauthConfig: NativeOAuthConfig;
|
||||
```
|
||||
|
||||
(For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
and refresh are handled entirely in Rust.
|
||||
|
||||
***
|
||||
|
||||
### readConsistencyInterval?
|
||||
|
||||
```ts
|
||||
optional readConsistencyInterval: number;
|
||||
```
|
||||
|
||||
The interval, in seconds, at which to check for updates to the table
|
||||
from other processes. If None, then consistency is not checked. For
|
||||
performance reasons, this is the default. For strong consistency, set
|
||||
this to zero seconds. Then every read will check for updates from other
|
||||
processes. As a compromise, you can set this to a non-zero value for
|
||||
eventual consistency. If more than that interval has passed since the
|
||||
last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
(For LanceDB OSS only): The interval, in seconds, at which to check for
|
||||
updates to the table from other processes. If None, then consistency is not
|
||||
checked. For performance reasons, this is the default. For strong
|
||||
consistency, set this to zero seconds. Then every read will check for
|
||||
updates from other processes. As a compromise, you can set this to a
|
||||
non-zero value for eventual consistency. If more than that interval
|
||||
has passed since the last check, then the table will be checked for updates.
|
||||
Note: this consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
Stronger consistency is not free. The smaller the interval, the more
|
||||
often each read pays the cost of checking for updates against object
|
||||
storage, raising per-read latency and cost.
|
||||
|
||||
***
|
||||
|
||||
### region?
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / FieldMetadataUpdate
|
||||
|
||||
# Interface: FieldMetadataUpdate
|
||||
|
||||
A per-field metadata update, addressed by dot-path.
|
||||
|
||||
## Properties
|
||||
|
||||
### metadata
|
||||
|
||||
```ts
|
||||
metadata: Record<string, null | string>;
|
||||
```
|
||||
|
||||
Metadata key/value pairs. Merged into the field's existing metadata by
|
||||
default; a value of `null` deletes that key.
|
||||
|
||||
***
|
||||
|
||||
### path
|
||||
|
||||
```ts
|
||||
path: string;
|
||||
```
|
||||
|
||||
Dot-separated path to the field. For a top-level column this is just its
|
||||
name; for a nested field it's the path, e.g. "a.b.c".
|
||||
|
||||
***
|
||||
|
||||
### replace?
|
||||
|
||||
```ts
|
||||
optional replace: boolean;
|
||||
```
|
||||
|
||||
If true, replace the field's entire metadata map instead of merging.
|
||||
@@ -11,10 +11,7 @@ Specification selecting Lance's MemWAL LSM-style write path for
|
||||
|
||||
`specType` is `"bucket"`, `"identity"`, or `"unsharded"`. For `"bucket"`,
|
||||
`column` and `numBuckets` are required; for `"identity"`, `column` is
|
||||
required and must be a deterministic function of the unenforced primary
|
||||
key (every row with a given primary key must always produce the same
|
||||
`column` value, or upserts of that key can land in different shards and a
|
||||
stale version can win).
|
||||
required.
|
||||
|
||||
## Properties
|
||||
|
||||
|
||||
@@ -32,14 +32,6 @@ numInsertedRows: number;
|
||||
|
||||
***
|
||||
|
||||
### numRows
|
||||
|
||||
```ts
|
||||
numRows: number;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### numUpdatedRows
|
||||
|
||||
```ts
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / RenameTableOptions
|
||||
|
||||
# Interface: RenameTableOptions
|
||||
|
||||
## Properties
|
||||
|
||||
### namespacePath?
|
||||
|
||||
```ts
|
||||
optional namespacePath: string[];
|
||||
```
|
||||
|
||||
The namespace path of the table being renamed. Defaults to the root
|
||||
namespace (`[]`) when omitted.
|
||||
|
||||
***
|
||||
|
||||
### newNamespacePath?
|
||||
|
||||
```ts
|
||||
optional newNamespacePath: string[];
|
||||
```
|
||||
|
||||
The namespace path to move the table to as part of the rename. When
|
||||
omitted the table stays in `namespacePath`.
|
||||
@@ -1,15 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / UpdateFieldMetadataResult
|
||||
|
||||
# Interface: UpdateFieldMetadataResult
|
||||
|
||||
## Properties
|
||||
|
||||
### version
|
||||
|
||||
```ts
|
||||
version: number;
|
||||
```
|
||||
@@ -1,84 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / WriteProgress
|
||||
|
||||
# Interface: WriteProgress
|
||||
|
||||
Progress snapshot for a write operation, delivered to the `progress`
|
||||
callback passed to [Table.add](../classes/Table.md#add).
|
||||
|
||||
## Properties
|
||||
|
||||
### activeTasks
|
||||
|
||||
```ts
|
||||
activeTasks: number;
|
||||
```
|
||||
|
||||
Number of parallel write tasks currently in flight.
|
||||
|
||||
***
|
||||
|
||||
### done
|
||||
|
||||
```ts
|
||||
done: boolean;
|
||||
```
|
||||
|
||||
`true` for the final callback; `false` otherwise.
|
||||
|
||||
***
|
||||
|
||||
### elapsedSeconds
|
||||
|
||||
```ts
|
||||
elapsedSeconds: number;
|
||||
```
|
||||
|
||||
Wall-clock seconds since the write started.
|
||||
|
||||
***
|
||||
|
||||
### outputBytes
|
||||
|
||||
```ts
|
||||
outputBytes: number;
|
||||
```
|
||||
|
||||
Number of bytes written so far.
|
||||
|
||||
***
|
||||
|
||||
### outputRows
|
||||
|
||||
```ts
|
||||
outputRows: number;
|
||||
```
|
||||
|
||||
Number of rows written so far.
|
||||
|
||||
***
|
||||
|
||||
### totalRows?
|
||||
|
||||
```ts
|
||||
optional totalRows: number;
|
||||
```
|
||||
|
||||
Total rows expected, when the input source reports it.
|
||||
|
||||
Always set on the final callback (the one with `done: true`), falling
|
||||
back to the actual number of rows written when the source could not
|
||||
report a row count up front.
|
||||
|
||||
***
|
||||
|
||||
### totalTasks
|
||||
|
||||
```ts
|
||||
totalTasks: number;
|
||||
```
|
||||
|
||||
Total number of parallel write tasks (the write parallelism).
|
||||
@@ -166,12 +166,6 @@ lists the indices that LanceDb supports.
|
||||
|
||||
::: lancedb.index.IvfFlat
|
||||
|
||||
::: lancedb.index.IvfSq
|
||||
|
||||
::: lancedb.index.IvfRq
|
||||
|
||||
::: lancedb.index.HnswFlat
|
||||
|
||||
::: lancedb.table.IndexStatistics
|
||||
|
||||
## Querying (Asynchronous)
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<version>0.29.1-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<version>0.29.1-beta.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>7.2.0-beta.1</lance-core.version>
|
||||
<lance-core.version>7.0.0-beta.13</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.30.1-beta.0"
|
||||
version = "0.29.1-beta.0"
|
||||
publish = false
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
@@ -47,14 +47,6 @@ describe("given a connection", () => {
|
||||
await db.close();
|
||||
expect(db.isOpen()).toBe(false);
|
||||
await expect(db.tableNames()).rejects.toThrow("Connection is closed");
|
||||
await expect(db.renameTable("a", "b")).rejects.toThrow(
|
||||
"Connection is closed",
|
||||
);
|
||||
});
|
||||
|
||||
it("should report renameTable as unsupported on an OSS connection", async () => {
|
||||
await db.createTable("a", [{ id: 1 }]);
|
||||
await expect(db.renameTable("a", "b")).rejects.toThrow(/not supported/);
|
||||
});
|
||||
it("should be able to create a table from an object arg `createTable(options)`, or args `createTable(name, data, options)`", async () => {
|
||||
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
|
||||
@@ -89,6 +81,16 @@ describe("given a connection", () => {
|
||||
await db.createTable("test4", [{ id: 1 }, { id: 2 }]);
|
||||
});
|
||||
|
||||
it("should expose renameTable and reject on OSS listing DB", async () => {
|
||||
await db.createTable("old_name", [{ id: 1 }]);
|
||||
|
||||
await expect(db.renameTable("old_name", "new_name")).rejects.toThrow(
|
||||
"rename_table is not supported in LanceDB OSS",
|
||||
);
|
||||
|
||||
await expect(db.tableNames()).resolves.toEqual(["old_name"]);
|
||||
});
|
||||
|
||||
it("should fail if creating table twice, unless overwrite is true", async () => {
|
||||
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
|
||||
await expect(tbl.countRows()).resolves.toBe(2);
|
||||
@@ -171,22 +173,18 @@ describe("given a connection", () => {
|
||||
|
||||
let manifestDir =
|
||||
tmpDir.name + "/test_manifest_paths_v2_empty.lance/_versions";
|
||||
readdirSync(manifestDir)
|
||||
.filter((f) => f.endsWith(".manifest"))
|
||||
.forEach((file) => {
|
||||
expect(file).toMatch(/^\d{20}\.manifest$/);
|
||||
});
|
||||
readdirSync(manifestDir).forEach((file) => {
|
||||
expect(file).toMatch(/^\d{20}\.manifest$/);
|
||||
});
|
||||
|
||||
table = (await db.createTable("test_manifest_paths_v2", [{ id: 1 }], {
|
||||
enableV2ManifestPaths: true,
|
||||
})) as LocalTable;
|
||||
expect(await table.usesV2ManifestPaths()).toBe(true);
|
||||
manifestDir = tmpDir.name + "/test_manifest_paths_v2.lance/_versions";
|
||||
readdirSync(manifestDir)
|
||||
.filter((f) => f.endsWith(".manifest"))
|
||||
.forEach((file) => {
|
||||
expect(file).toMatch(/^\d{20}\.manifest$/);
|
||||
});
|
||||
readdirSync(manifestDir).forEach((file) => {
|
||||
expect(file).toMatch(/^\d{20}\.manifest$/);
|
||||
});
|
||||
});
|
||||
|
||||
it("should be able to migrate tables to the V2 manifest paths", async () => {
|
||||
@@ -203,20 +201,16 @@ describe("given a connection", () => {
|
||||
|
||||
const manifestDir =
|
||||
tmpDir.name + "/test_manifest_path_migration.lance/_versions";
|
||||
readdirSync(manifestDir)
|
||||
.filter((f) => f.endsWith(".manifest"))
|
||||
.forEach((file) => {
|
||||
expect(file).toMatch(/^\d\.manifest$/);
|
||||
});
|
||||
readdirSync(manifestDir).forEach((file) => {
|
||||
expect(file).toMatch(/^\d\.manifest$/);
|
||||
});
|
||||
|
||||
await table.migrateManifestPathsV2();
|
||||
expect(await table.usesV2ManifestPaths()).toBe(true);
|
||||
|
||||
readdirSync(manifestDir)
|
||||
.filter((f) => f.endsWith(".manifest"))
|
||||
.forEach((file) => {
|
||||
expect(file).toMatch(/^\d{20}\.manifest$/);
|
||||
});
|
||||
readdirSync(manifestDir).forEach((file) => {
|
||||
expect(file).toMatch(/^\d{20}\.manifest$/);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -617,68 +617,4 @@ describe("remote connection", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("renameTable", () => {
|
||||
async function captureRenameRequest(
|
||||
call: (db: Connection) => Promise<void>,
|
||||
): Promise<{ url: string; body: Record<string, unknown> }> {
|
||||
let captured: { url: string; body: Record<string, unknown> } | undefined;
|
||||
await withMockDatabase((req, res) => {
|
||||
let raw = "";
|
||||
req.on("data", (chunk) => {
|
||||
raw += chunk;
|
||||
});
|
||||
req.on("end", () => {
|
||||
captured = {
|
||||
url: req.url ?? "",
|
||||
body: raw ? JSON.parse(raw) : {},
|
||||
};
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end("");
|
||||
});
|
||||
}, call);
|
||||
if (!captured) {
|
||||
throw new Error("mock server never saw a request");
|
||||
}
|
||||
return captured;
|
||||
}
|
||||
|
||||
it("sends rename request for a table in the root namespace", async () => {
|
||||
const { url, body } = await captureRenameRequest(async (db) => {
|
||||
await db.renameTable("table1", "table2");
|
||||
});
|
||||
expect(url).toBe("/v1/table/table1/rename/");
|
||||
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
|
||||
expect(body).toEqual({ new_table_name: "table2" });
|
||||
});
|
||||
|
||||
it("omits new_namespace when only the current namespace is supplied", async () => {
|
||||
// Safe-default check: passing namespacePath alone must not send
|
||||
// `new_namespace`, so the server keeps the table in its current
|
||||
// namespace instead of silently moving it to root.
|
||||
const { url, body } = await captureRenameRequest(async (db) => {
|
||||
await db.renameTable("table1", "table2", {
|
||||
namespacePath: ["ns1"],
|
||||
});
|
||||
});
|
||||
expect(url).toBe("/v1/table/ns1$table1/rename/");
|
||||
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
|
||||
expect(body).toEqual({ new_table_name: "table2" });
|
||||
});
|
||||
|
||||
it("includes new_namespace in the body for a cross-namespace rename", async () => {
|
||||
const { url, body } = await captureRenameRequest(async (db) => {
|
||||
await db.renameTable("table1", "table2", {
|
||||
namespacePath: ["ns1"],
|
||||
newNamespacePath: ["ns2"],
|
||||
});
|
||||
});
|
||||
expect(url).toBe("/v1/table/ns1$table1/rename/");
|
||||
expect(body).toEqual({
|
||||
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
|
||||
new_table_name: "table2",
|
||||
// biome-ignore lint/style/useNamingConvention: snake_case mandated by the server wire format
|
||||
new_namespace: ["ns2"],
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -28,7 +28,6 @@ import {
|
||||
List,
|
||||
Schema,
|
||||
SchemaLike,
|
||||
Struct,
|
||||
Type,
|
||||
Uint8,
|
||||
Utf8,
|
||||
@@ -85,39 +84,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
await expect(table.countRows()).resolves.toBe(3);
|
||||
});
|
||||
|
||||
it("should support branches", async () => {
|
||||
await table.add([{ id: 1 }]);
|
||||
expect(await table.countRows()).toBe(1);
|
||||
|
||||
// fork an isolated, writable branch from main
|
||||
const branch = await (await table.branches()).create("exp");
|
||||
expect(await branch.countRows()).toBe(1);
|
||||
await branch.add([{ id: 2 }]);
|
||||
expect(await branch.countRows()).toBe(2);
|
||||
// main is untouched by branch writes
|
||||
expect(await table.countRows()).toBe(1);
|
||||
|
||||
// listed, with main (null) as the parent
|
||||
const list = await (await table.branches()).list();
|
||||
expect(Object.keys(list)).toContain("exp");
|
||||
expect(list["exp"].parentBranch).toBeNull();
|
||||
|
||||
// fromRef="main" is equivalent to the default
|
||||
await (await table.branches()).create("exp2", "main");
|
||||
const list2 = await (await table.branches()).list();
|
||||
expect(list2["exp2"].parentBranch).toBeNull();
|
||||
|
||||
// checkout returns a handle scoped to the branch's latest
|
||||
const checkedOut = await (await table.branches()).checkout("exp");
|
||||
expect(await checkedOut.countRows()).toBe(2);
|
||||
|
||||
// delete removes it
|
||||
await (await table.branches()).delete("exp");
|
||||
await (await table.branches()).delete("exp2");
|
||||
const after = await (await table.branches()).list();
|
||||
expect(Object.keys(after)).not.toContain("exp");
|
||||
});
|
||||
|
||||
it("should show table stats", async () => {
|
||||
await table.add([{ id: 1 }, { id: 2 }]);
|
||||
await table.add([{ id: 1 }]);
|
||||
@@ -149,48 +115,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
await expect(table.countRows()).resolves.toBe(1);
|
||||
});
|
||||
|
||||
it("should invoke the progress callback", async () => {
|
||||
const events: import("../lancedb").WriteProgress[] = [];
|
||||
await table.add([{ id: 1 }, { id: 2 }, { id: 3 }], {
|
||||
progress: (p) => events.push(p),
|
||||
});
|
||||
|
||||
expect(events.length).toBeGreaterThan(0);
|
||||
const last = events[events.length - 1];
|
||||
expect(last.done).toBe(true);
|
||||
// Earlier callbacks must have done=false.
|
||||
for (const ev of events.slice(0, -1)) {
|
||||
expect(ev.done).toBe(false);
|
||||
}
|
||||
// outputRows reflects the rows added in this call, not table size.
|
||||
expect(last.outputRows).toBe(3);
|
||||
// The input source (an array) reports a row count, so totalRows is set.
|
||||
expect(last.totalRows).toBe(3);
|
||||
// outputRows is monotonic.
|
||||
for (let i = 1; i < events.length; i++) {
|
||||
expect(events[i].outputRows).toBeGreaterThanOrEqual(
|
||||
events[i - 1].outputRows,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("should swallow errors thrown from the progress callback", async () => {
|
||||
const warn = jest
|
||||
.spyOn(console, "warn")
|
||||
.mockImplementation(() => undefined);
|
||||
try {
|
||||
const res = await table.add([{ id: 1 }, { id: 2 }], {
|
||||
progress: () => {
|
||||
throw new Error("callback bomb");
|
||||
},
|
||||
});
|
||||
expect(res.version).toBeGreaterThan(0);
|
||||
expect(warn).toHaveBeenCalled();
|
||||
} finally {
|
||||
warn.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it("should let me close the table", async () => {
|
||||
expect(table.isOpen()).toBe(true);
|
||||
table.close();
|
||||
@@ -814,113 +738,6 @@ describe("When creating an index", () => {
|
||||
expect(indices2.length).toBe(0);
|
||||
});
|
||||
|
||||
it("should create and search a nested vector index", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const nestedSchema = new Schema([
|
||||
new Field("id", new Int32(), true),
|
||||
new Field(
|
||||
"image",
|
||||
new Struct([
|
||||
new Field(
|
||||
"embedding",
|
||||
new FixedSizeList(2, new Field("item", new Float32(), true)),
|
||||
true,
|
||||
),
|
||||
]),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
const nestedTable = await db.createTable(
|
||||
"nested_vector",
|
||||
makeArrowTable(
|
||||
Array.from({ length: 300 }, (_, id) => ({
|
||||
id,
|
||||
image: { embedding: [id, id + 1] },
|
||||
})),
|
||||
{ schema: nestedSchema },
|
||||
),
|
||||
);
|
||||
|
||||
await nestedTable.createIndex("image.embedding", {
|
||||
name: "image_embedding_idx",
|
||||
});
|
||||
const indices = await nestedTable.listIndices();
|
||||
expect(indices).toContainEqual({
|
||||
name: "image_embedding_idx",
|
||||
indexType: "IvfPq",
|
||||
columns: ["image.embedding"],
|
||||
});
|
||||
|
||||
const explicit = await nestedTable
|
||||
.query()
|
||||
.nearestTo([0.0, 1.0])
|
||||
.column("image.embedding")
|
||||
.limit(1)
|
||||
.toArray();
|
||||
const inferred = await nestedTable
|
||||
.query()
|
||||
.nearestTo([0.0, 1.0])
|
||||
.limit(1)
|
||||
.toArray();
|
||||
expect(inferred[0].id).toEqual(explicit[0].id);
|
||||
});
|
||||
|
||||
it("should report multiple nested vector candidates", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const nestedSchema = new Schema([
|
||||
new Field(
|
||||
"image",
|
||||
new Struct([
|
||||
new Field(
|
||||
"embedding",
|
||||
new FixedSizeList(2, new Field("item", new Float32(), true)),
|
||||
true,
|
||||
),
|
||||
]),
|
||||
true,
|
||||
),
|
||||
new Field(
|
||||
"text",
|
||||
new Struct([
|
||||
new Field(
|
||||
"embedding",
|
||||
new FixedSizeList(2, new Field("item", new Float32(), true)),
|
||||
true,
|
||||
),
|
||||
]),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
const nestedTable = await db.createTable(
|
||||
"multiple_nested_vectors",
|
||||
makeArrowTable(
|
||||
[
|
||||
{
|
||||
image: { embedding: [0.0, 1.0] },
|
||||
text: { embedding: [2.0, 3.0] },
|
||||
},
|
||||
],
|
||||
{ schema: nestedSchema },
|
||||
),
|
||||
);
|
||||
|
||||
await expect(
|
||||
nestedTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(),
|
||||
).rejects.toThrow(/image\.embedding.*text\.embedding/);
|
||||
});
|
||||
|
||||
it("should report when no default vector column exists", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const noVectorTable = await db.createTable(
|
||||
"no_vector",
|
||||
makeArrowTable([{ id: 0, label: "cat" }]),
|
||||
);
|
||||
|
||||
await expect(
|
||||
noVectorTable.query().nearestTo([0.0, 1.0]).limit(1).toArray(),
|
||||
).rejects.toThrow(/No vector column/);
|
||||
});
|
||||
|
||||
it("should wait for index readiness", async () => {
|
||||
// Create an index and then wait for it to be ready
|
||||
await tbl.createIndex("vec");
|
||||
@@ -1604,33 +1421,6 @@ describe("schema evolution", function () {
|
||||
expect(await table.schema()).toEqual(expectedSchema3);
|
||||
});
|
||||
|
||||
it("can update field metadata", async function () {
|
||||
const con = await connect(tmpDir.name);
|
||||
const table = await con.createTable("fm", [
|
||||
{ id: 1, category: "a" },
|
||||
{ id: 2, category: "b" },
|
||||
]);
|
||||
|
||||
const res = await table.updateFieldMetadata([
|
||||
{ path: "category", metadata: { unit: "label", pii: "false" } },
|
||||
]);
|
||||
expect(res).toHaveProperty("version");
|
||||
expect(res.version).toBe(2);
|
||||
|
||||
let cat = (await table.schema()).fields.find((f) => f.name === "category");
|
||||
expect(cat?.metadata.get("unit")).toBe("label");
|
||||
expect(cat?.metadata.get("pii")).toBe("false");
|
||||
|
||||
// merge: add a key, delete one via null, keep the rest
|
||||
await table.updateFieldMetadata([
|
||||
{ path: "category", metadata: { source: "import", pii: null } },
|
||||
]);
|
||||
cat = (await table.schema()).fields.find((f) => f.name === "category");
|
||||
expect(cat?.metadata.get("unit")).toBe("label"); // preserved
|
||||
expect(cat?.metadata.get("source")).toBe("import"); // added
|
||||
expect(cat?.metadata.has("pii")).toBe(false); // deleted
|
||||
});
|
||||
|
||||
it("can cast to various types", async function () {
|
||||
const con = await connect(tmpDir.name);
|
||||
|
||||
@@ -2685,97 +2475,3 @@ describe("setLsmWriteSpec / unsetLsmWriteSpec", () => {
|
||||
).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("LSM merge insert", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
|
||||
beforeEach(() => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
});
|
||||
afterEach(() => tmpDir.removeCallback());
|
||||
|
||||
async function bucketTable(conn: Connection): Promise<Table> {
|
||||
// The primary key column must be non-nullable.
|
||||
const table = await conn.createEmptyTable(
|
||||
"t",
|
||||
new arrow.Schema([
|
||||
new arrow.Field("id", new arrow.Utf8(), false),
|
||||
new arrow.Field("value", new arrow.Float64(), true),
|
||||
]),
|
||||
);
|
||||
await table.add([
|
||||
{ id: "a", value: 1 },
|
||||
{ id: "b", value: 2 },
|
||||
]);
|
||||
await table.setUnenforcedPrimaryKey("id");
|
||||
// numBuckets = 1: every row routes to the single bucket.
|
||||
await table.setLsmWriteSpec({
|
||||
specType: "bucket",
|
||||
column: "id",
|
||||
numBuckets: 1,
|
||||
});
|
||||
return table;
|
||||
}
|
||||
|
||||
it("routes merge_insert through the shard writer", async () => {
|
||||
const conn = await connect(tmpDir.name);
|
||||
const table = await bucketTable(conn);
|
||||
|
||||
const res = await table
|
||||
.mergeInsert("id")
|
||||
.whenMatchedUpdateAll()
|
||||
.whenNotMatchedInsertAll()
|
||||
.execute([
|
||||
{ id: "c", value: 3 },
|
||||
{ id: "d", value: 4 },
|
||||
]);
|
||||
// LSM path: rows go to the MemWAL, so only numRows is populated.
|
||||
expect(res.numRows).toBe(2);
|
||||
expect(res.version).toBe(0);
|
||||
expect(res.numInsertedRows).toBe(0);
|
||||
|
||||
await table.closeLsmWriters();
|
||||
});
|
||||
|
||||
it("falls back to the standard path with useLsmWrite(false)", async () => {
|
||||
const conn = await connect(tmpDir.name);
|
||||
const table = await bucketTable(conn);
|
||||
|
||||
const res = await table
|
||||
.mergeInsert("id")
|
||||
.whenNotMatchedInsertAll()
|
||||
.useLsmWrite(false)
|
||||
.execute([
|
||||
{ id: "b", value: 9 },
|
||||
{ id: "e", value: 5 },
|
||||
]);
|
||||
// Standard path commits: id="e" inserted ("b" already exists).
|
||||
expect(res.numInsertedRows).toBe(1);
|
||||
expect(await table.countRows()).toBe(3);
|
||||
});
|
||||
|
||||
it("supports validateSingleShard(false)", async () => {
|
||||
const conn = await connect(tmpDir.name);
|
||||
const table = await bucketTable(conn);
|
||||
|
||||
const res = await table
|
||||
.mergeInsert("id")
|
||||
.whenMatchedUpdateAll()
|
||||
.whenNotMatchedInsertAll()
|
||||
.validateSingleShard(false)
|
||||
.execute([{ id: "f", value: 6 }]);
|
||||
expect(res.numRows).toBe(1);
|
||||
});
|
||||
|
||||
it("rejects a non-upsert merge under an LSM spec", async () => {
|
||||
const conn = await connect(tmpDir.name);
|
||||
const table = await bucketTable(conn);
|
||||
|
||||
await expect(
|
||||
table
|
||||
.mergeInsert("id")
|
||||
.whenNotMatchedInsertAll()
|
||||
.execute([{ id: "g", value: 7 }]),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -144,19 +144,6 @@ export interface DropNamespaceOptions {
|
||||
behavior?: "restrict" | "cascade";
|
||||
}
|
||||
|
||||
export interface RenameTableOptions {
|
||||
/**
|
||||
* The namespace path of the table being renamed. Defaults to the root
|
||||
* namespace (`[]`) when omitted.
|
||||
*/
|
||||
namespacePath?: string[];
|
||||
/**
|
||||
* The namespace path to move the table to as part of the rename. When
|
||||
* omitted the table stays in `namespacePath`.
|
||||
*/
|
||||
newNamespacePath?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* A LanceDB Connection that allows you to open tables and create new ones.
|
||||
*
|
||||
@@ -309,6 +296,12 @@ export abstract class Connection {
|
||||
*/
|
||||
abstract dropTable(name: string, namespacePath?: string[]): Promise<void>;
|
||||
|
||||
abstract renameTable(
|
||||
oldName: string,
|
||||
newName: string,
|
||||
namespacePath?: string[],
|
||||
): Promise<void>;
|
||||
|
||||
/**
|
||||
* Drop all tables in the database.
|
||||
* @param {string[]} namespacePath The namespace path to drop tables from (defaults to root namespace).
|
||||
@@ -404,24 +397,6 @@ export abstract class Connection {
|
||||
isShallow?: boolean;
|
||||
},
|
||||
): Promise<Table>;
|
||||
|
||||
/**
|
||||
* Rename a table.
|
||||
*
|
||||
* Currently only supported by LanceDB Cloud. Local OSS connections and
|
||||
* namespace-backed connections (via {@link connectNamespace}) reject with
|
||||
* a "not supported" error.
|
||||
*
|
||||
* @param {string} currentName - The current name of the table.
|
||||
* @param {string} newName - The new name for the table.
|
||||
* @param {RenameTableOptions} options - Optional namespace paths. When
|
||||
* `newNamespacePath` is omitted the table stays in `namespacePath`.
|
||||
*/
|
||||
abstract renameTable(
|
||||
currentName: string,
|
||||
newName: string,
|
||||
options?: RenameTableOptions,
|
||||
): Promise<void>;
|
||||
}
|
||||
|
||||
/** @hideconstructor */
|
||||
@@ -640,6 +615,14 @@ export class LocalConnection extends Connection {
|
||||
return this.inner.dropTable(name, namespacePath ?? []);
|
||||
}
|
||||
|
||||
async renameTable(
|
||||
oldName: string,
|
||||
newName: string,
|
||||
namespacePath?: string[],
|
||||
): Promise<void> {
|
||||
return this.inner.renameTable(oldName, newName, namespacePath ?? []);
|
||||
}
|
||||
|
||||
async dropAllTables(namespacePath?: string[]): Promise<void> {
|
||||
return this.inner.dropAllTables(namespacePath ?? []);
|
||||
}
|
||||
@@ -682,19 +665,6 @@ export class LocalConnection extends Connection {
|
||||
options?.behavior,
|
||||
);
|
||||
}
|
||||
|
||||
async renameTable(
|
||||
currentName: string,
|
||||
newName: string,
|
||||
options?: RenameTableOptions,
|
||||
): Promise<void> {
|
||||
return this.inner.renameTable(
|
||||
currentName,
|
||||
newName,
|
||||
options?.namespacePath ?? [],
|
||||
options?.newNamespacePath,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -38,12 +38,10 @@ export {
|
||||
FragmentSummaryStats,
|
||||
Tags,
|
||||
TagContents,
|
||||
BranchContents,
|
||||
MergeResult,
|
||||
AddResult,
|
||||
AddColumnsResult,
|
||||
AlterColumnsResult,
|
||||
UpdateFieldMetadataResult,
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
UpdateResult,
|
||||
@@ -52,6 +50,7 @@ export {
|
||||
SplitHashOptions,
|
||||
SplitSequentialOptions,
|
||||
ShuffleOptions,
|
||||
OAuthConfig as NativeOAuthConfig,
|
||||
} from "./native.js";
|
||||
|
||||
export {
|
||||
@@ -73,7 +72,6 @@ export {
|
||||
CreateNamespaceResponse,
|
||||
DropNamespaceResponse,
|
||||
DescribeNamespaceResponse,
|
||||
RenameTableOptions,
|
||||
} from "./connection";
|
||||
|
||||
export { Session } from "./native.js";
|
||||
@@ -112,15 +110,12 @@ export {
|
||||
|
||||
export {
|
||||
Table,
|
||||
Branches,
|
||||
AddDataOptions,
|
||||
UpdateOptions,
|
||||
OptimizeOptions,
|
||||
Version,
|
||||
WriteProgress,
|
||||
LsmWriteSpec,
|
||||
ColumnAlteration,
|
||||
FieldMetadataUpdate,
|
||||
} from "./table";
|
||||
|
||||
export {
|
||||
@@ -130,6 +125,8 @@ export {
|
||||
TokenResponse,
|
||||
} from "./header";
|
||||
|
||||
export { OAuthConfig, OAuthFlowType } from "./oauth";
|
||||
|
||||
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
|
||||
@@ -87,41 +87,6 @@ export class MergeInsertBuilder {
|
||||
this.#schema,
|
||||
);
|
||||
}
|
||||
/**
|
||||
* Controls whether the merge uses the MemWAL LSM write path.
|
||||
*
|
||||
* By default (unset), a `mergeInsert` on a table with an LSM write spec is
|
||||
* routed through Lance's MemWAL shard writer, and a table without one uses
|
||||
* the standard path. Pass `false` to force the standard path even when a
|
||||
* spec is set. Pass `true` to require a spec — `mergeInsert` rejects if none
|
||||
* is installed.
|
||||
*
|
||||
* @param useLsmWrite - Whether to use the LSM write path.
|
||||
*/
|
||||
useLsmWrite(useLsmWrite: boolean): MergeInsertBuilder {
|
||||
return new MergeInsertBuilder(
|
||||
this.#native.useLsmWrite(useLsmWrite),
|
||||
this.#schema,
|
||||
);
|
||||
}
|
||||
/**
|
||||
* Controls how an LSM merge checks that its input targets a single shard.
|
||||
*
|
||||
* When a table has an LSM write spec, every row in a `mergeInsert` call must
|
||||
* route to the same shard. When `true` (the default), every row is inspected
|
||||
* to verify this. When `false`, only the first row is inspected and the
|
||||
* shard it routes to is used for the whole input — a faster path for callers
|
||||
* that have already pre-sharded their input. Has no effect on tables without
|
||||
* an LSM write spec.
|
||||
*
|
||||
* @param validateSingleShard - Whether to check every row routes to one shard. Defaults to `true`.
|
||||
*/
|
||||
validateSingleShard(validateSingleShard: boolean): MergeInsertBuilder {
|
||||
return new MergeInsertBuilder(
|
||||
this.#native.validateSingleShard(validateSingleShard),
|
||||
this.#schema,
|
||||
);
|
||||
}
|
||||
/**
|
||||
* Executes the merge insert operation
|
||||
*
|
||||
|
||||
82
nodejs/lancedb/oauth.ts
Normal file
82
nodejs/lancedb/oauth.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
/**
|
||||
* OAuth authentication flow types.
|
||||
*/
|
||||
export enum OAuthFlowType {
|
||||
/** Client Credentials grant (service-to-service / M2M). */
|
||||
ClientCredentials = "client_credentials",
|
||||
/** Authorization Code with PKCE (interactive browser-based auth). */
|
||||
AuthorizationCodePKCE = "authorization_code_pkce",
|
||||
/** Device Code grant (CLI / headless environments). */
|
||||
DeviceCode = "device_code",
|
||||
/** Azure Managed Identity via IMDS. */
|
||||
AzureManagedIdentity = "azure_managed_identity",
|
||||
/** Workload Identity Federation (K8s, GitHub Actions). */
|
||||
WorkloadIdentity = "workload_identity",
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth configuration for LanceDB authentication.
|
||||
*
|
||||
* All token acquisition and refresh is handled in the Rust layer.
|
||||
* This config is passed through to Rust via napi-rs.
|
||||
*
|
||||
* @example Client Credentials (service-to-service):
|
||||
* ```typescript
|
||||
* const config: OAuthConfig = {
|
||||
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
* clientId: "app-id",
|
||||
* clientSecret: "secret",
|
||||
* scopes: ["api://lancedb-api/.default"],
|
||||
* };
|
||||
* ```
|
||||
*
|
||||
* @example Azure Managed Identity:
|
||||
* ```typescript
|
||||
* const config: OAuthConfig = {
|
||||
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
* clientId: "app-id",
|
||||
* scopes: ["api://lancedb-api/.default"],
|
||||
* flow: OAuthFlowType.AzureManagedIdentity,
|
||||
* };
|
||||
* ```
|
||||
*/
|
||||
export interface OAuthConfig {
|
||||
/**
|
||||
* OIDC issuer URL or OAuth authority URL.
|
||||
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
*/
|
||||
issuerUrl: string;
|
||||
|
||||
/** Application / Client ID. */
|
||||
clientId: string;
|
||||
|
||||
/**
|
||||
* OAuth scopes to request.
|
||||
* For Azure: `["api://{app_id}/.default"]`
|
||||
*/
|
||||
scopes: string[];
|
||||
|
||||
/** Authentication flow (default: ClientCredentials). */
|
||||
flow?: OAuthFlowType;
|
||||
|
||||
/** Client secret (required for ClientCredentials). */
|
||||
clientSecret?: string;
|
||||
|
||||
/** Redirect URI (AuthorizationCodePKCE flow). */
|
||||
redirectUri?: string;
|
||||
|
||||
/** Port for local callback server (AuthorizationCodePKCE, default: 8400). */
|
||||
callbackPort?: number;
|
||||
|
||||
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
|
||||
managedIdentityClientId?: string;
|
||||
|
||||
/** Path to federated token file (WorkloadIdentity). */
|
||||
tokenFile?: string;
|
||||
|
||||
/** Seconds before expiry to trigger proactive refresh (default: 300). */
|
||||
refreshBufferSecs?: number;
|
||||
}
|
||||
@@ -25,16 +25,13 @@ import {
|
||||
AddColumnsSql,
|
||||
AddResult,
|
||||
AlterColumnsResult,
|
||||
BranchContents,
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
IndexConfig,
|
||||
IndexStatistics,
|
||||
Branches as NativeBranches,
|
||||
OptimizeStats,
|
||||
TableStatistics,
|
||||
Tags,
|
||||
UpdateFieldMetadataResult,
|
||||
UpdateResult,
|
||||
Table as _NativeTable,
|
||||
} from "./native";
|
||||
@@ -49,33 +46,6 @@ import { sanitizeType } from "./sanitize";
|
||||
import { IntoSql, toSQL } from "./util";
|
||||
export { IndexConfig } from "./native";
|
||||
|
||||
/**
|
||||
* Progress snapshot for a write operation, delivered to the `progress`
|
||||
* callback passed to {@link Table.add}.
|
||||
*/
|
||||
export interface WriteProgress {
|
||||
/** Number of rows written so far. */
|
||||
outputRows: number;
|
||||
/** Number of bytes written so far. */
|
||||
outputBytes: number;
|
||||
/**
|
||||
* Total rows expected, when the input source reports it.
|
||||
*
|
||||
* Always set on the final callback (the one with `done: true`), falling
|
||||
* back to the actual number of rows written when the source could not
|
||||
* report a row count up front.
|
||||
*/
|
||||
totalRows?: number;
|
||||
/** Wall-clock seconds since the write started. */
|
||||
elapsedSeconds: number;
|
||||
/** Number of parallel write tasks currently in flight. */
|
||||
activeTasks: number;
|
||||
/** Total number of parallel write tasks (the write parallelism). */
|
||||
totalTasks: number;
|
||||
/** `true` for the final callback; `false` otherwise. */
|
||||
done: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for adding data to a table.
|
||||
*/
|
||||
@@ -86,28 +56,6 @@ export interface AddDataOptions {
|
||||
* If "overwrite" then the new data will replace the existing data in the table.
|
||||
*/
|
||||
mode: "append" | "overwrite";
|
||||
|
||||
/**
|
||||
* Optional callback invoked periodically with write progress.
|
||||
*
|
||||
* The callback is fired once per batch written and once more with
|
||||
* `done: true` when the write completes. Calls are dispatched
|
||||
* asynchronously to the JS event loop and never block the write — a slow
|
||||
* callback will queue events rather than back-pressure the writer.
|
||||
*
|
||||
* Errors thrown from the callback are logged with `console.warn` and
|
||||
* swallowed — they do not abort the write.
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* await table.add(data, {
|
||||
* progress: (p) => {
|
||||
* console.log(`${p.outputRows}/${p.totalRows ?? "?"} rows`);
|
||||
* },
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
progress: (progress: WriteProgress) => void;
|
||||
}
|
||||
|
||||
export interface UpdateOptions {
|
||||
@@ -164,10 +112,7 @@ export interface Version {
|
||||
*
|
||||
* `specType` is `"bucket"`, `"identity"`, or `"unsharded"`. For `"bucket"`,
|
||||
* `column` and `numBuckets` are required; for `"identity"`, `column` is
|
||||
* required and must be a deterministic function of the unenforced primary
|
||||
* key (every row with a given primary key must always produce the same
|
||||
* `column` value, or upserts of that key can land in different shards and a
|
||||
* stale version can win).
|
||||
* required.
|
||||
*/
|
||||
export interface LsmWriteSpec {
|
||||
/** One of `"bucket"`, `"identity"`, or `"unsharded"`. */
|
||||
@@ -511,18 +456,6 @@ export abstract class Table {
|
||||
abstract alterColumns(
|
||||
columnAlterations: ColumnAlteration[],
|
||||
): Promise<AlterColumnsResult>;
|
||||
|
||||
/**
|
||||
* Update per-field (column) metadata.
|
||||
* @param {FieldMetadataUpdate[]} updates One or more per-field updates. Each
|
||||
* update's metadata is merged into the field's existing metadata by default;
|
||||
* a value of `null` deletes that key, and `replace: true` swaps the whole map.
|
||||
* @returns {Promise<UpdateFieldMetadataResult>} resolves to the new table version.
|
||||
*/
|
||||
abstract updateFieldMetadata(
|
||||
updates: FieldMetadataUpdate[],
|
||||
): Promise<UpdateFieldMetadataResult>;
|
||||
|
||||
/**
|
||||
* Drop one or more columns from the dataset
|
||||
*
|
||||
@@ -585,16 +518,6 @@ export abstract class Table {
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
abstract unsetLsmWriteSpec(): Promise<void>;
|
||||
/**
|
||||
* Drain and close any cached MemWAL shard writers held for this table.
|
||||
*
|
||||
* When an {@link LsmWriteSpec} is installed, `mergeInsert` opens MemWAL
|
||||
* shard writers and caches them for reuse across calls. This closes them,
|
||||
* flushing pending data; writers reopen lazily on the next `mergeInsert`.
|
||||
* It is a no-op when no writers are cached.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
abstract closeLsmWriters(): Promise<void>;
|
||||
/** Retrieve the version of the table */
|
||||
|
||||
abstract version(): Promise<number>;
|
||||
@@ -655,14 +578,6 @@ export abstract class Table {
|
||||
*/
|
||||
abstract tags(): Promise<Tags>;
|
||||
|
||||
/**
|
||||
* Get the branch manager for this table.
|
||||
*
|
||||
* Branches are isolated, writable lines of history forked from another
|
||||
* branch (or version). Writes on a branch do not affect `main`.
|
||||
*/
|
||||
abstract branches(): Promise<Branches>;
|
||||
|
||||
/**
|
||||
* Restore the table to the currently checked out version
|
||||
*
|
||||
@@ -790,20 +705,7 @@ export class LocalTable extends Table {
|
||||
const schema = await this.schema();
|
||||
|
||||
const buffer = await fromDataToBuffer(data, undefined, schema);
|
||||
// Wrap the user callback so a thrown error doesn't surface as an
|
||||
// unhandled exception (the callback fires from a napi threadsafe
|
||||
// function — exceptions there crash the process).
|
||||
const userProgress = options?.progress;
|
||||
const progress = userProgress
|
||||
? (p: WriteProgress) => {
|
||||
try {
|
||||
userProgress(p);
|
||||
} catch (e) {
|
||||
console.warn("Table.add progress callback threw:", e);
|
||||
}
|
||||
}
|
||||
: undefined;
|
||||
return await this.inner.add(buffer, mode, progress);
|
||||
return await this.inner.add(buffer, mode);
|
||||
}
|
||||
|
||||
async update(
|
||||
@@ -1060,12 +962,6 @@ export class LocalTable extends Table {
|
||||
return await this.inner.alterColumns(processedAlterations);
|
||||
}
|
||||
|
||||
async updateFieldMetadata(
|
||||
updates: FieldMetadataUpdate[],
|
||||
): Promise<UpdateFieldMetadataResult> {
|
||||
return await this.inner.updateFieldMetadata(updates);
|
||||
}
|
||||
|
||||
async dropColumns(columnNames: string[]): Promise<DropColumnsResult> {
|
||||
return await this.inner.dropColumns(columnNames);
|
||||
}
|
||||
@@ -1083,10 +979,6 @@ export class LocalTable extends Table {
|
||||
return await this.inner.unsetLsmWriteSpec();
|
||||
}
|
||||
|
||||
async closeLsmWriters(): Promise<void> {
|
||||
return await this.inner.closeLsmWriters();
|
||||
}
|
||||
|
||||
async version(): Promise<number> {
|
||||
return await this.inner.version();
|
||||
}
|
||||
@@ -1118,10 +1010,6 @@ export class LocalTable extends Table {
|
||||
return await this.inner.tags();
|
||||
}
|
||||
|
||||
async branches(): Promise<Branches> {
|
||||
return new Branches(await this.inner.branches());
|
||||
}
|
||||
|
||||
async optimize(options?: Partial<OptimizeOptions>): Promise<OptimizeStats> {
|
||||
let cleanupOlderThanMs;
|
||||
if (
|
||||
@@ -1236,67 +1124,3 @@ export interface ColumnAlteration {
|
||||
/** Set the new nullability. Note that a nullable column cannot be made non-nullable. */
|
||||
nullable?: boolean;
|
||||
}
|
||||
|
||||
/** A per-field metadata update, addressed by dot-path. */
|
||||
export interface FieldMetadataUpdate {
|
||||
/**
|
||||
* Dot-separated path to the field. For a top-level column this is just its
|
||||
* name; for a nested field it's the path, e.g. "a.b.c".
|
||||
*/
|
||||
path: string;
|
||||
/**
|
||||
* Metadata key/value pairs. Merged into the field's existing metadata by
|
||||
* default; a value of `null` deletes that key.
|
||||
*/
|
||||
metadata: Record<string, string | null>;
|
||||
/** If true, replace the field's entire metadata map instead of merging. */
|
||||
replace?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Branch manager for a {@link Table}.
|
||||
*
|
||||
* Unlike tags, `create` and `checkout` return a new {@link Table} handle scoped
|
||||
* to the branch; writes on it do not affect `main`.
|
||||
*/
|
||||
export class Branches {
|
||||
#inner: NativeBranches;
|
||||
|
||||
/**
|
||||
* Construct a Branches manager. Internal use only.
|
||||
* @hidden
|
||||
*/
|
||||
constructor(inner: NativeBranches) {
|
||||
this.#inner = inner;
|
||||
}
|
||||
|
||||
/** List all branches, mapping name to branch metadata. */
|
||||
async list(): Promise<Record<string, BranchContents>> {
|
||||
return await this.#inner.list();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a branch and return a handle scoped to it.
|
||||
*
|
||||
* @param name Name of the new branch.
|
||||
* @param fromRef Source branch to fork from. Defaults to `main`.
|
||||
* @param fromVersion A specific version on `fromRef`. Defaults to latest.
|
||||
*/
|
||||
async create(
|
||||
name: string,
|
||||
fromRef?: string,
|
||||
fromVersion?: number,
|
||||
): Promise<Table> {
|
||||
return new LocalTable(await this.#inner.create(name, fromRef, fromVersion));
|
||||
}
|
||||
|
||||
/** Check out an existing branch and return a handle scoped to it. */
|
||||
async checkout(name: string): Promise<Table> {
|
||||
return new LocalTable(await this.#inner.checkout(name));
|
||||
}
|
||||
|
||||
/** Delete a branch. */
|
||||
async delete(name: string): Promise<void> {
|
||||
return await this.#inner.delete(name);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.30.1-beta.0",
|
||||
"version": "0.29.1-beta.0",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -112,6 +112,11 @@ impl Connection {
|
||||
|
||||
builder = builder.client_config(rust_config);
|
||||
|
||||
if let Some(oauth_config) = options.oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig = oauth_config.into();
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
|
||||
if let Some(api_key) = options.api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
}
|
||||
@@ -328,6 +333,20 @@ impl Connection {
|
||||
.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn rename_table(
|
||||
&self,
|
||||
old_name: String,
|
||||
new_name: String,
|
||||
namespace_path: Option<Vec<String>>,
|
||||
) -> napi::Result<()> {
|
||||
let ns = namespace_path.unwrap_or_default();
|
||||
self.get_inner()?
|
||||
.rename_table(&old_name, &new_name, &ns, &ns)
|
||||
.await
|
||||
.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn drop_all_tables(&self, namespace_path: Option<Vec<String>>) -> napi::Result<()> {
|
||||
let ns = namespace_path.unwrap_or_default();
|
||||
@@ -459,23 +478,4 @@ impl Connection {
|
||||
transaction_id: resp.transaction_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Rename a table. `current_namespace_path` and `new_namespace_path` default to
|
||||
/// the root namespace when omitted; the caller is expected to either pass both
|
||||
/// or pass neither.
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn rename_table(
|
||||
&self,
|
||||
current_name: String,
|
||||
new_name: String,
|
||||
current_namespace_path: Option<Vec<String>>,
|
||||
new_namespace_path: Option<Vec<String>>,
|
||||
) -> napi::Result<()> {
|
||||
let cur_ns = current_namespace_path.unwrap_or_default();
|
||||
let new_ns = new_namespace_path.unwrap_or_default();
|
||||
self.get_inner()?
|
||||
.rename_table(¤t_name, &new_name, &cur_ns, &new_ns)
|
||||
.await
|
||||
.default_error()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,19 +24,15 @@ mod util;
|
||||
#[napi(object)]
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectionOptions {
|
||||
/// The interval, in seconds, at which to check for updates to the table
|
||||
/// from other processes. If None, then consistency is not checked. For
|
||||
/// performance reasons, this is the default. For strong consistency, set
|
||||
/// this to zero seconds. Then every read will check for updates from other
|
||||
/// processes. As a compromise, you can set this to a non-zero value for
|
||||
/// eventual consistency. If more than that interval has passed since the
|
||||
/// last check, then the table will be checked for updates. Note: this
|
||||
/// consistency only applies to read operations. Write operations are
|
||||
/// (For LanceDB OSS only): The interval, in seconds, at which to check for
|
||||
/// updates to the table from other processes. If None, then consistency is not
|
||||
/// checked. For performance reasons, this is the default. For strong
|
||||
/// consistency, set this to zero seconds. Then every read will check for
|
||||
/// updates from other processes. As a compromise, you can set this to a
|
||||
/// non-zero value for eventual consistency. If more than that interval
|
||||
/// has passed since the last check, then the table will be checked for updates.
|
||||
/// Note: this consistency only applies to read operations. Write operations are
|
||||
/// always consistent.
|
||||
///
|
||||
/// Stronger consistency is not free. The smaller the interval, the more
|
||||
/// often each read pays the cost of checking for updates against object
|
||||
/// storage, raising per-read latency and cost.
|
||||
pub read_consistency_interval: Option<f64>,
|
||||
/// (For LanceDB OSS only): configuration for object storage.
|
||||
///
|
||||
@@ -65,6 +61,10 @@ pub struct ConnectionOptions {
|
||||
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
|
||||
/// for testing purposes.
|
||||
pub host_override: Option<String>,
|
||||
/// (For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
/// authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
/// and refresh are handled entirely in Rust.
|
||||
pub oauth_config: Option<remote::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
|
||||
@@ -50,20 +50,6 @@ impl NativeMergeInsertBuilder {
|
||||
this
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn use_lsm_write(&self, use_lsm_write: bool) -> Self {
|
||||
let mut this = self.clone();
|
||||
this.inner.use_lsm_write(use_lsm_write);
|
||||
this
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn validate_single_shard(&self, validate_single_shard: bool) -> Self {
|
||||
let mut this = self.clone();
|
||||
this.inner.validate_single_shard(validate_single_shard);
|
||||
this
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(&self, buf: Buffer) -> napi::Result<MergeResult> {
|
||||
let data = ipc_file_to_batches(buf.to_vec())
|
||||
|
||||
@@ -140,6 +140,67 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth configuration for LanceDB authentication.
|
||||
/// All token acquisition and refresh is handled in the Rust layer.
|
||||
#[napi(object)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OAuthConfig {
|
||||
/// OIDC issuer URL or OAuth authority URL.
|
||||
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
pub issuer_url: String,
|
||||
/// Application / Client ID.
|
||||
pub client_id: String,
|
||||
/// OAuth scopes to request. For Azure: `["api://{app_id}/.default"]`
|
||||
pub scopes: Vec<String>,
|
||||
/// Authentication flow: "client_credentials", "authorization_code_pkce",
|
||||
/// "device_code", "azure_managed_identity", "workload_identity"
|
||||
pub flow: Option<String>,
|
||||
/// Client secret (required for client_credentials).
|
||||
pub client_secret: Option<String>,
|
||||
/// Redirect URI (authorization_code_pkce flow).
|
||||
pub redirect_uri: Option<String>,
|
||||
/// Port for local callback server (authorization_code_pkce, default: 8400).
|
||||
pub callback_port: Option<u16>,
|
||||
/// Client ID for user-assigned managed identity (azure_managed_identity).
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
/// Path to federated token file (workload_identity).
|
||||
pub token_file: Option<String>,
|
||||
/// Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
pub refresh_buffer_secs: Option<u32>,
|
||||
}
|
||||
|
||||
impl From<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
|
||||
fn from(config: OAuthConfig) -> Self {
|
||||
use lancedb::remote::oauth::OAuthFlow;
|
||||
|
||||
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
|
||||
"authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE {
|
||||
redirect_uri: config.redirect_uri,
|
||||
callback_port: config.callback_port,
|
||||
},
|
||||
"device_code" => OAuthFlow::DeviceCode,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: config.managed_identity_client_id,
|
||||
},
|
||||
"workload_identity" => OAuthFlow::WorkloadIdentity {
|
||||
token_file: config
|
||||
.token_file
|
||||
.expect("tokenFile is required for workload_identity flow"),
|
||||
},
|
||||
other => panic!("Unknown OAuth flow type: {other}"),
|
||||
};
|
||||
|
||||
Self {
|
||||
issuer_url: config.issuer_url,
|
||||
client_id: config.client_id,
|
||||
client_secret: config.client_secret,
|
||||
scopes: config.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -5,12 +5,10 @@ use std::collections::HashMap;
|
||||
|
||||
use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema};
|
||||
use lancedb::table::{
|
||||
AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration,
|
||||
FieldMetadataUpdate as LanceFieldMetadataUpdate, NewColumnTransform, OptimizeAction,
|
||||
OptimizeOptions, Ref, Table as LanceDbTable,
|
||||
AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, NewColumnTransform,
|
||||
OptimizeAction, OptimizeOptions, Table as LanceDbTable,
|
||||
};
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi::threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode};
|
||||
use napi_derive::napi;
|
||||
|
||||
use crate::error::NapiErrorExt;
|
||||
@@ -69,16 +67,8 @@ impl Table {
|
||||
schema_to_buffer(&schema)
|
||||
}
|
||||
|
||||
#[napi(
|
||||
catch_unwind,
|
||||
ts_args_type = "buf: Buffer, mode: string, progressCallback?: (progress: WriteProgressInfo) => void"
|
||||
)]
|
||||
pub async fn add(
|
||||
&self,
|
||||
buf: Buffer,
|
||||
mode: String,
|
||||
progress_callback: Option<ProgressFn>,
|
||||
) -> napi::Result<AddResult> {
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<AddResult> {
|
||||
let batches = ipc_file_to_batches(buf.to_vec())
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||
let batches = batches
|
||||
@@ -102,19 +92,6 @@ impl Table {
|
||||
return Err(napi::Error::from_reason(format!("Invalid mode: {}", mode)));
|
||||
};
|
||||
|
||||
if let Some(tsfn) = progress_callback {
|
||||
op = op.progress(move |p| {
|
||||
// NonBlocking: dispatch onto the JS event loop without
|
||||
// blocking the writer thread. With napi-rs's default
|
||||
// unbounded queue, events are not dropped — a slow JS
|
||||
// callback will just queue them.
|
||||
tsfn.call(
|
||||
WriteProgressInfo::from(p),
|
||||
ThreadsafeFunctionCallMode::NonBlocking,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
let res = op.execute().await.default_error()?;
|
||||
Ok(res.into())
|
||||
}
|
||||
@@ -356,23 +333,6 @@ impl Table {
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn update_field_metadata(
|
||||
&self,
|
||||
updates: Vec<FieldMetadataUpdate>,
|
||||
) -> napi::Result<UpdateFieldMetadataResult> {
|
||||
let updates = updates
|
||||
.into_iter()
|
||||
.map(LanceFieldMetadataUpdate::from)
|
||||
.collect::<Vec<_>>();
|
||||
let res = self
|
||||
.inner_ref()?
|
||||
.update_field_metadata(&updates)
|
||||
.await
|
||||
.default_error()?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<DropColumnsResult> {
|
||||
let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
|
||||
@@ -409,11 +369,6 @@ impl Table {
|
||||
.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn close_lsm_writers(&self) -> napi::Result<()> {
|
||||
self.inner_ref()?.close_lsm_writers().await.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn version(&self) -> napi::Result<i64> {
|
||||
self.inner_ref()?
|
||||
@@ -478,13 +433,6 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn branches(&self) -> napi::Result<Branches> {
|
||||
Ok(Branches {
|
||||
inner: self.inner_ref()?.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn optimize(
|
||||
&self,
|
||||
@@ -706,44 +654,6 @@ pub struct OptimizeStats {
|
||||
pub prune: RemovalStats,
|
||||
}
|
||||
|
||||
/// Progress snapshot for a write operation, delivered to the JS callback
|
||||
/// passed to `Table.add`.
|
||||
#[napi(object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WriteProgressInfo {
|
||||
/// Number of rows written so far.
|
||||
pub output_rows: i64,
|
||||
/// Number of bytes written so far.
|
||||
pub output_bytes: i64,
|
||||
/// Total rows expected, if the input source reports it.
|
||||
/// Always set on the final callback (where `done` is `true`).
|
||||
pub total_rows: Option<i64>,
|
||||
/// Wall-clock seconds since monitoring started.
|
||||
pub elapsed_seconds: f64,
|
||||
/// Number of parallel write tasks currently in flight.
|
||||
pub active_tasks: i64,
|
||||
/// Total number of parallel write tasks (the write parallelism).
|
||||
pub total_tasks: i64,
|
||||
/// `true` for the final callback; `false` otherwise.
|
||||
pub done: bool,
|
||||
}
|
||||
|
||||
impl From<&lancedb::table::write_progress::WriteProgress> for WriteProgressInfo {
|
||||
fn from(p: &lancedb::table::write_progress::WriteProgress) -> Self {
|
||||
Self {
|
||||
output_rows: p.output_rows() as i64,
|
||||
output_bytes: p.output_bytes() as i64,
|
||||
total_rows: p.total_rows().map(|n| n as i64),
|
||||
elapsed_seconds: p.elapsed().as_secs_f64(),
|
||||
active_tasks: p.active_tasks() as i64,
|
||||
total_tasks: p.total_tasks() as i64,
|
||||
done: p.done(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ProgressFn = ThreadsafeFunction<WriteProgressInfo, (), WriteProgressInfo, Status, false>;
|
||||
|
||||
/// A definition of a column alteration. The alteration changes the column at
|
||||
/// `path` to have the new name `name`, to be nullable if `nullable` is true,
|
||||
/// and to have the data type `data_type`. At least one of `rename` or `nullable`
|
||||
@@ -772,29 +682,6 @@ pub struct ColumnAlteration {
|
||||
pub nullable: Option<bool>,
|
||||
}
|
||||
|
||||
/// A per-field metadata update, addressed by dot-path. Merges into the field's
|
||||
/// existing metadata by default; a `null` value deletes a key, and `replace`
|
||||
/// swaps the field's entire metadata map.
|
||||
#[napi(object)]
|
||||
pub struct FieldMetadataUpdate {
|
||||
/// Dot-separated path to the field (e.g. "embedding" or "a.b.c").
|
||||
pub path: String,
|
||||
/// Metadata keys to set; a `null` value deletes that key.
|
||||
pub metadata: HashMap<String, Option<String>>,
|
||||
/// If true, replace the field's entire metadata map instead of merging.
|
||||
pub replace: Option<bool>,
|
||||
}
|
||||
|
||||
impl From<FieldMetadataUpdate> for LanceFieldMetadataUpdate {
|
||||
fn from(js: FieldMetadataUpdate) -> Self {
|
||||
Self {
|
||||
path: js.path,
|
||||
metadata: js.metadata,
|
||||
replace: js.replace.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ColumnAlteration> for LanceColumnAlteration {
|
||||
type Error = String;
|
||||
fn try_from(js: ColumnAlteration) -> std::result::Result<Self, Self::Error> {
|
||||
@@ -993,7 +880,6 @@ pub struct MergeResult {
|
||||
pub num_updated_rows: i64,
|
||||
pub num_deleted_rows: i64,
|
||||
pub num_attempts: i64,
|
||||
pub num_rows: i64,
|
||||
}
|
||||
|
||||
impl From<lancedb::table::MergeResult> for MergeResult {
|
||||
@@ -1004,7 +890,6 @@ impl From<lancedb::table::MergeResult> for MergeResult {
|
||||
num_updated_rows: value.num_updated_rows as i64,
|
||||
num_deleted_rows: value.num_deleted_rows as i64,
|
||||
num_attempts: value.num_attempts as i64,
|
||||
num_rows: value.num_rows as i64,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1035,19 +920,6 @@ impl From<lancedb::table::AlterColumnsResult> for AlterColumnsResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct UpdateFieldMetadataResult {
|
||||
pub version: i64,
|
||||
}
|
||||
|
||||
impl From<lancedb::table::UpdateFieldMetadataResult> for UpdateFieldMetadataResult {
|
||||
fn from(value: lancedb::table::UpdateFieldMetadataResult) -> Self {
|
||||
Self {
|
||||
version: value.version as i64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct DropColumnsResult {
|
||||
pub version: i64,
|
||||
@@ -1067,13 +939,6 @@ pub struct TagContents {
|
||||
pub manifest_size: i64,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct BranchContents {
|
||||
pub parent_branch: Option<String>,
|
||||
pub parent_version: i64,
|
||||
pub manifest_size: i64,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct Tags {
|
||||
inner: LanceDbTable,
|
||||
@@ -1142,60 +1007,3 @@ impl Tags {
|
||||
.default_error()
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct Branches {
|
||||
inner: LanceDbTable,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl Branches {
|
||||
#[napi]
|
||||
pub async fn list(&self) -> napi::Result<HashMap<String, BranchContents>> {
|
||||
let branches = self.inner.list_branches().await.default_error()?;
|
||||
let result = branches
|
||||
.into_iter()
|
||||
.map(|(k, v)| {
|
||||
(
|
||||
k,
|
||||
BranchContents {
|
||||
parent_branch: v.parent_branch,
|
||||
parent_version: v.parent_version as i64,
|
||||
manifest_size: v.manifest_size as i64,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn create(
|
||||
&self,
|
||||
name: String,
|
||||
from_ref: Option<String>,
|
||||
from_version: Option<i64>,
|
||||
) -> napi::Result<Table> {
|
||||
// "main" and None are two spellings of the root branch; normalize so
|
||||
// from_ref = "main" behaves identically to the default.
|
||||
let from_ref = from_ref.filter(|b| b != "main");
|
||||
let from = Ref::Version(from_ref, from_version.map(|v| v as u64));
|
||||
let table = self
|
||||
.inner
|
||||
.create_branch(&name, from)
|
||||
.await
|
||||
.default_error()?;
|
||||
Ok(Table::new(table))
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn checkout(&self, name: String) -> napi::Result<Table> {
|
||||
let table = self.inner.checkout_branch(&name).await.default_error()?;
|
||||
Ok(Table::new(table))
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn delete(&self, name: String) -> napi::Result<()> {
|
||||
self.inner.delete_branch(&name).await.default_error()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.33.1-beta.0"
|
||||
current_version = "0.32.1-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -4,26 +4,16 @@ code is in the `src/` directory and the Python bindings are in the `lancedb/` di
|
||||
|
||||
Common commands:
|
||||
|
||||
* Bootstrap dev env: `uv run --extra tests --extra dev maturin develop --extras tests,dev`
|
||||
* Build: `make develop`
|
||||
* Format: `make format`
|
||||
* Lint: `make check`
|
||||
* Fix lints: `make fix`
|
||||
* Test: `uv run --extra tests pytest python/tests -vv --durations=10 -m "not slow and not s3_test"`
|
||||
* Run specific test: `uv run --extra tests pytest python/tests/<test_file>.py::<test_name> -q`
|
||||
* Doc test: `uv run --extra tests pytest --doctest-modules python/lancedb`
|
||||
|
||||
Use the uv-managed environment declared by `uv.lock` for Python validation. Do
|
||||
not treat system `python`, global `pytest`, or missing editable-install errors
|
||||
as final blockers; bootstrap or enter the uv environment instead. `make test`
|
||||
and `make doctest` assume the development environment is already prepared.
|
||||
* Test: `make test`
|
||||
* Doc test: `make doctest`
|
||||
|
||||
Before committing changes, run lints and then formatting.
|
||||
|
||||
When you change the Rust code, PyO3 binding code, or see a missing/stale
|
||||
`lancedb._lancedb`, recompile the Python bindings with
|
||||
`uv run --extra tests --extra dev maturin develop --extras tests,dev` before
|
||||
running tests.
|
||||
When you change the Rust code, you will need to recompile the Python bindings: `make develop`.
|
||||
|
||||
When you export new types from Rust to Python, you must manually update `python/lancedb/_lancedb.pyi`
|
||||
with the corresponding type hints. You can run `pyright` to check for type errors in the Python code.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.33.1-beta.0"
|
||||
version = "0.32.1-beta.0"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
|
||||
@@ -94,6 +94,7 @@ def connect(
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
@@ -103,10 +104,6 @@ def connect(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
Stronger consistency is not free. The smaller the interval, the more
|
||||
often each read pays the cost of checking for updates against object
|
||||
storage, raising per-read latency and cost.
|
||||
client_config: ClientConfig or dict, optional
|
||||
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||
the keys are the attributes of the ClientConfig class. If None, then the
|
||||
@@ -150,13 +147,6 @@ def connect(
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb",
|
||||
... storage_options={"aws_access_key_id": "***"})
|
||||
|
||||
For tests and temporary data, use an in-memory database:
|
||||
|
||||
>>> db = lancedb.connect("memory://")
|
||||
|
||||
In-memory databases are not persisted. Tables are dropped when the last
|
||||
connection or table handle referencing them is closed.
|
||||
|
||||
Connect to LanceDB cloud:
|
||||
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...",
|
||||
@@ -220,7 +210,6 @@ def connect(
|
||||
request_thread_pool=request_thread_pool,
|
||||
client_config=client_config,
|
||||
storage_options=storage_options,
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
**kwargs,
|
||||
)
|
||||
_check_s3_bucket_with_dots(str(uri), storage_options)
|
||||
@@ -315,15 +304,6 @@ def deserialize_conn(
|
||||
manifest_enabled=parsed.get("manifest_enabled", False),
|
||||
namespace_client_properties=parsed.get("namespace_client_properties"),
|
||||
)
|
||||
elif connection_type == "remote":
|
||||
return RemoteDBConnection(
|
||||
parsed["db_url"],
|
||||
parsed["api_key"],
|
||||
parsed.get("region", "us-east-1"),
|
||||
host_override=parsed.get("host_override"),
|
||||
client_config=parsed.get("client_config"),
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown connection_type: {connection_type}")
|
||||
|
||||
@@ -340,6 +320,7 @@ async def connect_async(
|
||||
session: Optional[Session] = None,
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config=None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -356,6 +337,7 @@ async def connect_async(
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
@@ -365,10 +347,6 @@ async def connect_async(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
Stronger consistency is not free. The smaller the interval, the more
|
||||
often each read pays the cost of checking for updates against object
|
||||
storage, raising per-read latency and cost.
|
||||
client_config: ClientConfig or dict, optional
|
||||
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||
the keys are the attributes of the ClientConfig class. If None, then the
|
||||
@@ -401,8 +379,6 @@ async def connect_async(
|
||||
... db = await lancedb.connect_async("s3://my-bucket/lancedb",
|
||||
... storage_options={
|
||||
... "aws_access_key_id": "***"})
|
||||
... # For tests and temporary data, use an in-memory database
|
||||
... db = await lancedb.connect_async("memory://")
|
||||
... # Connect to LanceDB cloud
|
||||
... db = await lancedb.connect_async("db://my_database", api_key="ldb_...",
|
||||
... client_config={
|
||||
@@ -435,6 +411,7 @@ async def connect_async(
|
||||
session,
|
||||
manifest_enabled,
|
||||
namespace_client_properties,
|
||||
oauth_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -208,9 +208,6 @@ class Table:
|
||||
async def alter_columns(
|
||||
self, columns: list[dict[str, Any]]
|
||||
) -> AlterColumnsResult: ...
|
||||
async def update_field_metadata(
|
||||
self, updates: list[dict[str, Any]]
|
||||
) -> UpdateFieldMetadataResult: ...
|
||||
async def optimize(
|
||||
self,
|
||||
*,
|
||||
@@ -223,11 +220,8 @@ class Table:
|
||||
async def set_unenforced_primary_key(self, columns: List[str]) -> None: ...
|
||||
async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ...
|
||||
async def unset_lsm_write_spec(self) -> None: ...
|
||||
async def close_lsm_writers(self) -> None: ...
|
||||
@property
|
||||
def tags(self) -> Tags: ...
|
||||
@property
|
||||
def branches(self) -> Branches: ...
|
||||
def query(self) -> Query: ...
|
||||
def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
|
||||
def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
|
||||
@@ -240,17 +234,6 @@ class Tags:
|
||||
async def delete(self, tag: str): ...
|
||||
async def update(self, tag: str, version: int): ...
|
||||
|
||||
class Branches:
|
||||
async def list(self) -> Dict[str, Any]: ...
|
||||
async def create(
|
||||
self,
|
||||
name: str,
|
||||
from_ref: Optional[str] = None,
|
||||
from_version: Optional[int] = None,
|
||||
) -> Table: ...
|
||||
async def checkout(self, name: str) -> Table: ...
|
||||
async def delete(self, name: str) -> None: ...
|
||||
|
||||
class IndexConfig:
|
||||
name: str
|
||||
index_type: str
|
||||
@@ -267,6 +250,7 @@ async def connect(
|
||||
session: Optional[Session],
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config: Optional[Any] = None,
|
||||
) -> Connection: ...
|
||||
|
||||
class RecordBatchStream:
|
||||
@@ -437,7 +421,6 @@ class MergeResult:
|
||||
num_inserted_rows: int
|
||||
num_deleted_rows: int
|
||||
num_attempts: int
|
||||
num_rows: int
|
||||
|
||||
class LsmWriteSpec:
|
||||
"""Specification selecting Lance's MemWAL LSM-style write path for
|
||||
@@ -476,9 +459,6 @@ class AddColumnsResult:
|
||||
class AlterColumnsResult:
|
||||
version: int
|
||||
|
||||
class UpdateFieldMetadataResult:
|
||||
version: int
|
||||
|
||||
class DropColumnsResult:
|
||||
version: int
|
||||
|
||||
|
||||
@@ -8,17 +8,7 @@ from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -323,7 +313,7 @@ class DBConnection(EnforceOverrides):
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data)
|
||||
LanceTable(name='my_table', ...)
|
||||
LanceTable(name='my_table', version=1, ...)
|
||||
>>> db["my_table"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -344,7 +334,7 @@ class DBConnection(EnforceOverrides):
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data)
|
||||
LanceTable(name='table2', ...)
|
||||
LanceTable(name='table2', version=1, ...)
|
||||
>>> db["table2"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -367,7 +357,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema)
|
||||
LanceTable(name='table3', ...)
|
||||
LanceTable(name='table3', version=1, ...)
|
||||
>>> db["table3"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -401,7 +391,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("price", pa.float32()),
|
||||
... ])
|
||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||
LanceTable(name='table4', ...)
|
||||
LanceTable(name='table4', version=1, ...)
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -578,15 +568,15 @@ class LanceDBConnection(DBConnection):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
||||
... {"vector": [0.5, 1.3], "b": 4}])
|
||||
LanceTable(name='my_table', ...)
|
||||
LanceTable(name='my_table', version=1, ...)
|
||||
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
||||
LanceTable(name='another_table', ...)
|
||||
LanceTable(name='another_table', version=1, ...)
|
||||
>>> sorted(db.table_names())
|
||||
['another_table', 'my_table']
|
||||
>>> len(db)
|
||||
2
|
||||
>>> db["my_table"]
|
||||
LanceTable(name='my_table', ...)
|
||||
LanceTable(name='my_table', version=1, ...)
|
||||
>>> "my_table" in db
|
||||
True
|
||||
>>> db.drop_table("my_table")
|
||||
@@ -857,20 +847,11 @@ class LanceDBConnection(DBConnection):
|
||||
)
|
||||
)
|
||||
|
||||
def _all_table_names(self) -> Generator[str, None, None]:
|
||||
page_token = None
|
||||
while True:
|
||||
response = self.list_tables(page_token=page_token)
|
||||
yield from response.tables
|
||||
page_token = response.page_token
|
||||
if not page_token:
|
||||
return
|
||||
|
||||
def __len__(self) -> int:
|
||||
return sum(1 for _ in self._all_table_names())
|
||||
return len(self.table_names())
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._all_table_names()
|
||||
return name in self.table_names()
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
|
||||
@@ -281,9 +281,6 @@ class HnswPq:
|
||||
m: int = 20
|
||||
ef_construction: int = 300
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -389,9 +386,6 @@ class HnswSq:
|
||||
m: int = 20
|
||||
ef_construction: int = 300
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -585,9 +579,6 @@ class IvfFlat:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -618,9 +609,6 @@ class IvfSq:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -751,9 +739,6 @@ class IvfPq:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -807,9 +792,6 @@ class IvfRq:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -34,8 +34,6 @@ class LanceMergeInsertBuilder(object):
|
||||
self._when_not_matched_by_source_condition = None
|
||||
self._timeout = None
|
||||
self._use_index = True
|
||||
self._use_lsm_write = None
|
||||
self._validate_single_shard = None
|
||||
|
||||
def when_matched_update_all(
|
||||
self, *, where: Optional[str] = None
|
||||
@@ -98,46 +96,6 @@ class LanceMergeInsertBuilder(object):
|
||||
self._use_index = use_index
|
||||
return self
|
||||
|
||||
def use_lsm_write(self, use_lsm_write: bool) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Controls whether the merge uses the MemWAL LSM write path.
|
||||
|
||||
By default (unset), a `merge_insert` on a table with an LSM write spec
|
||||
is routed through Lance's MemWAL shard writer, and a table without one
|
||||
uses the standard path. Pass `False` to force the standard path even
|
||||
when a spec is set. Pass `True` to require a spec — `merge_insert`
|
||||
raises an error if none is installed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
use_lsm_write: bool
|
||||
Whether to use the LSM write path.
|
||||
"""
|
||||
self._use_lsm_write = use_lsm_write
|
||||
return self
|
||||
|
||||
def validate_single_shard(
|
||||
self, validate_single_shard: bool
|
||||
) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Controls how an LSM merge checks that its input targets a single shard.
|
||||
|
||||
When a table has an LSM write spec, every row in a `merge_insert` call
|
||||
must route to the same shard. When `True` (the default), every row is
|
||||
inspected to verify this. When `False`, only the first row is inspected
|
||||
and the shard it routes to is used for the whole input — a faster path
|
||||
for callers that have already pre-sharded their input.
|
||||
|
||||
Has no effect on tables without an LSM write spec.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
validate_single_shard: bool
|
||||
Whether to check every row routes to one shard. Defaults to `True`.
|
||||
"""
|
||||
self._validate_single_shard = validate_single_shard
|
||||
return self
|
||||
|
||||
def execute(
|
||||
self,
|
||||
new_data: DATA,
|
||||
|
||||
@@ -3,13 +3,12 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
from deprecation import deprecated
|
||||
import pyarrow as pa
|
||||
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable, Table
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from .util import batch_to_tensor, batch_to_tensor_rows
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
@@ -355,49 +354,6 @@ class Transforms:
|
||||
DEFAULT_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def _table_to_pickle_state(table: Table) -> dict[str, Any]:
|
||||
from .remote.table import RemoteTable
|
||||
|
||||
if isinstance(table, RemoteTable):
|
||||
return {
|
||||
"kind": "remote",
|
||||
"table": table,
|
||||
}
|
||||
|
||||
if not isinstance(table, LanceTable):
|
||||
raise ValueError(f"Cannot pickle table of type {type(table)!r}")
|
||||
|
||||
base_uri = table._conn.uri
|
||||
if base_uri.startswith("memory://"):
|
||||
return {
|
||||
"kind": "memory",
|
||||
"name": table.name,
|
||||
"data": table.to_arrow(),
|
||||
}
|
||||
|
||||
return {
|
||||
"kind": "local",
|
||||
"name": table.name,
|
||||
"uri": base_uri,
|
||||
"namespace": table._namespace_path,
|
||||
"storage_options": table._conn.storage_options,
|
||||
}
|
||||
|
||||
|
||||
def _table_from_pickle_state(state: dict[str, Any]) -> Table:
|
||||
from . import connect
|
||||
|
||||
kind = state["kind"]
|
||||
if kind == "remote":
|
||||
return state["table"]
|
||||
if kind == "memory":
|
||||
return connect("memory://").create_table(state["name"], state["data"])
|
||||
if kind == "local":
|
||||
db = connect(state["uri"], storage_options=state["storage_options"])
|
||||
return db.open_table(state["name"], namespace_path=state["namespace"] or None)
|
||||
raise ValueError(f"Unknown table pickle state kind: {kind}")
|
||||
|
||||
|
||||
class Permutation:
|
||||
"""
|
||||
A Permutation is a view of a dataset that can be used as input to model training
|
||||
@@ -413,15 +369,15 @@ class Permutation:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_table: Table,
|
||||
permutation_table: Optional[Table],
|
||||
base_table: LanceTable,
|
||||
permutation_table: Optional[LanceTable],
|
||||
split: int,
|
||||
selection: dict[str, str],
|
||||
batch_size: int,
|
||||
transform_fn: Callable[pa.RecordBatch, Any],
|
||||
offset: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
connection_factory: Optional[Callable[[str], Table]] = None,
|
||||
connection_factory: Optional[Callable[[str], LanceTable]] = None,
|
||||
_reader: Optional[PermutationReader] = None,
|
||||
):
|
||||
"""
|
||||
@@ -441,7 +397,6 @@ class Permutation:
|
||||
if _reader is None:
|
||||
_reader = LOOP.run(self._build_reader())
|
||||
self.reader: PermutationReader = _reader
|
||||
self._pid = os.getpid()
|
||||
|
||||
async def _build_reader(self) -> PermutationReader:
|
||||
reader = await PermutationReader.from_tables(
|
||||
@@ -473,25 +428,29 @@ class Permutation:
|
||||
return new
|
||||
|
||||
def with_connection_factory(
|
||||
self, connection_factory: Callable[[str], Table]
|
||||
self, connection_factory: Callable[[str], LanceTable]
|
||||
) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation that will use ``connection_factory`` to reopen
|
||||
the base table when this permutation is unpickled in a worker process.
|
||||
|
||||
The factory is a callable that takes a single argument, the base table
|
||||
name, and returns a LanceDB table. It must be picklable; the worker
|
||||
The factory is a callable that takes a single argument — the base table
|
||||
name — and returns a [LanceTable]. It must be picklable; the worker
|
||||
will pickle it via standard ``pickle`` and call it to recover the base
|
||||
table. Picklable callables in practice means top-level (module-level)
|
||||
functions, ``functools.partial`` of such functions, or instances of
|
||||
picklable classes implementing ``__call__``. Lambdas and closures over
|
||||
local variables don't pickle with the default protocol.
|
||||
|
||||
A factory is optional for normal local and remote LanceDB connections:
|
||||
if not set, ``__getstate__`` captures the table's own picklable reopen
|
||||
state. Use a factory when that default state is not enough, for example
|
||||
when credentials should be loaded from the worker environment instead
|
||||
of being embedded in the pickle.
|
||||
Setting a factory is necessary when the URI alone is not enough to
|
||||
re-open the connection — most importantly for LanceDB Cloud (``db://``)
|
||||
connections, where ``api_key`` and ``region`` aren't recoverable from
|
||||
the connection object after construction.
|
||||
|
||||
For local file or cloud-storage paths the factory is optional: if not
|
||||
set, ``__getstate__`` falls back to capturing
|
||||
``(uri, storage_options, namespace_path)`` and re-opening via
|
||||
``lancedb.connect(uri, storage_options=...)``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -549,7 +508,7 @@ class Permutation:
|
||||
return new
|
||||
|
||||
@classmethod
|
||||
def identity(cls, table: Table) -> "Permutation":
|
||||
def identity(cls, table: LanceTable) -> "Permutation":
|
||||
"""
|
||||
Creates an identity permutation for the given table.
|
||||
"""
|
||||
@@ -558,8 +517,8 @@ class Permutation:
|
||||
@classmethod
|
||||
def from_tables(
|
||||
cls,
|
||||
base_table: Table,
|
||||
permutation_table: Optional[Table] = None,
|
||||
base_table: LanceTable,
|
||||
permutation_table: Optional[LanceTable] = None,
|
||||
split: Optional[Union[str, int]] = None,
|
||||
) -> "Permutation":
|
||||
"""
|
||||
@@ -635,10 +594,11 @@ class Permutation:
|
||||
|
||||
The base table is captured either via a user-supplied
|
||||
``connection_factory`` (see [with_connection_factory]) or, as a
|
||||
fallback, by the table's own picklable reopen state. The permutation
|
||||
table is captured as a pyarrow Table (which pickles via Arrow IPC
|
||||
natively). The reader is dropped from the wire format and rebuilt
|
||||
lazily on first use.
|
||||
fallback, by introspecting ``(uri, storage_options, namespace_path)``
|
||||
on the connection. The permutation table — always an in-memory
|
||||
LanceDB table — is captured as a pyarrow Table (which pickles via
|
||||
Arrow IPC natively). The reader is dropped from the wire format;
|
||||
``__setstate__`` rebuilds it from the restored tables.
|
||||
"""
|
||||
permutation_data: Optional[pa.Table] = None
|
||||
if self.permutation_table is not None:
|
||||
@@ -662,9 +622,39 @@ class Permutation:
|
||||
# namespace from the existing connection.
|
||||
return common
|
||||
|
||||
# URI-introspection fallback: only viable for native (OSS) connections
|
||||
# where (uri, storage_options) is enough to reopen. Remote / cloud
|
||||
# connections don't expose recoverable api_key / region — those users
|
||||
# must call with_connection_factory().
|
||||
try:
|
||||
base_uri = self.base_table._conn.uri
|
||||
storage_options = self.base_table._conn.storage_options
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
"Cannot pickle this Permutation: the base table's connection "
|
||||
"does not expose a uri/storage_options, which usually means it "
|
||||
"is a remote (LanceDB Cloud) connection. Call "
|
||||
"Permutation.with_connection_factory(...) first to provide a "
|
||||
"picklable callable that re-opens the base table from a worker "
|
||||
"process."
|
||||
) from e
|
||||
|
||||
if base_uri.startswith("memory://"):
|
||||
# In-memory base tables don't exist in any worker process by
|
||||
# default, so dump the entire base table into the pickle. This
|
||||
# can be expensive for large datasets — users with large
|
||||
# in-memory base tables should either persist them or set a
|
||||
# connection_factory.
|
||||
return {
|
||||
**common,
|
||||
"base_table_data": self.base_table.to_arrow(),
|
||||
}
|
||||
|
||||
return {
|
||||
**common,
|
||||
"base_table_state": _table_to_pickle_state(self.base_table),
|
||||
"base_table_uri": base_uri,
|
||||
"base_table_namespace": self.base_table._namespace_path,
|
||||
"base_table_storage_options": storage_options,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
@@ -673,8 +663,6 @@ class Permutation:
|
||||
connection_factory = state["connection_factory"]
|
||||
if connection_factory is not None:
|
||||
base_table = connection_factory(state["base_table_name"])
|
||||
elif "base_table_state" in state:
|
||||
base_table = _table_from_pickle_state(state["base_table_state"])
|
||||
elif "base_table_data" in state:
|
||||
# In-memory base table inlined into the pickle; rebuild the same
|
||||
# way we rebuild the in-memory permutation table.
|
||||
@@ -692,7 +680,7 @@ class Permutation:
|
||||
namespace_path=state["base_table_namespace"] or None,
|
||||
)
|
||||
|
||||
permutation_table: Optional[Table] = None
|
||||
permutation_table: Optional[LanceTable] = None
|
||||
if state["permutation_data"] is not None:
|
||||
mem_db = connect("memory://")
|
||||
permutation_table = mem_db.create_table(
|
||||
@@ -708,28 +696,10 @@ class Permutation:
|
||||
self.offset = state["offset"]
|
||||
self.limit = state["limit"]
|
||||
self.connection_factory = connection_factory
|
||||
self.reader = None
|
||||
self._pid = None
|
||||
|
||||
def _ensure_open(self) -> None:
|
||||
pid = os.getpid()
|
||||
if self.reader is not None and getattr(self, "_pid", None) == pid:
|
||||
return
|
||||
# The reader owns Rust-side table handles. Rebuild it after unpickle or
|
||||
# fork even though the Python table wrappers reopen themselves.
|
||||
if hasattr(self.base_table, "_ensure_open"):
|
||||
self.base_table._ensure_open()
|
||||
if self.permutation_table is not None and hasattr(
|
||||
self.permutation_table, "_ensure_open"
|
||||
):
|
||||
self.permutation_table._ensure_open()
|
||||
self.reader = LOOP.run(self._build_reader())
|
||||
self._pid = pid
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
self._ensure_open()
|
||||
|
||||
async def do_output_schema():
|
||||
return await self.reader.output_schema(self.selection)
|
||||
|
||||
@@ -747,7 +717,6 @@ class Permutation:
|
||||
"""
|
||||
The number of rows in the permutation
|
||||
"""
|
||||
self._ensure_open()
|
||||
return self.reader.count_rows()
|
||||
|
||||
@property
|
||||
@@ -906,7 +875,6 @@ class Permutation:
|
||||
If skip_last_batch is True, the last batch will be skipped if it is not a
|
||||
multiple of batch_size.
|
||||
"""
|
||||
self._ensure_open()
|
||||
|
||||
async def get_iter():
|
||||
return await self.reader.read(self.selection, batch_size=batch_size)
|
||||
@@ -1008,7 +976,6 @@ class Permutation:
|
||||
so `with_format` and `with_transform` affect this method in the same way
|
||||
they affect iteration.
|
||||
"""
|
||||
self._ensure_open()
|
||||
|
||||
async def do_take_offsets():
|
||||
return await self.reader.take_offsets(offsets, selection=self.selection)
|
||||
@@ -1044,11 +1011,9 @@ class Permutation:
|
||||
"""
|
||||
Skip the first `skip` rows of the permutation
|
||||
"""
|
||||
self._ensure_open()
|
||||
new = copy.copy(self)
|
||||
new.offset = skip
|
||||
new.reader = LOOP.run(new._build_reader())
|
||||
new._pid = os.getpid()
|
||||
return new
|
||||
|
||||
@deprecated(details="Use with_take instead")
|
||||
@@ -1067,11 +1032,9 @@ class Permutation:
|
||||
"""
|
||||
Limit the permutation to `limit` rows (following any `skip`)
|
||||
"""
|
||||
self._ensure_open()
|
||||
new = copy.copy(self)
|
||||
new.limit = limit
|
||||
new.reader = LOOP.run(new._build_reader())
|
||||
new._pid = os.getpid()
|
||||
return new
|
||||
|
||||
@deprecated(details="Use with_repeat instead")
|
||||
|
||||
@@ -3,14 +3,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from datetime import timedelta
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
@@ -19,40 +17,41 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
Any,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import deprecation
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pydantic
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from lancedb._lancedb import fts_query_to_json
|
||||
from lancedb.background_loop import LOOP
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
from lancedb.background_loop import LOOP
|
||||
|
||||
from . import __version__
|
||||
from .arrow import AsyncRecordBatchReader
|
||||
from .dependencies import pandas as pd
|
||||
from .expr import Expr
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import flatten_columns
|
||||
from .expr import Expr
|
||||
from lancedb._lancedb import fts_query_to_json
|
||||
from typing_extensions import Annotated
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sys
|
||||
|
||||
import PIL
|
||||
import polars as pl
|
||||
|
||||
from ._lancedb import Query as LanceQuery
|
||||
from ._lancedb import FTSQuery as LanceFTSQuery
|
||||
from ._lancedb import HybridQuery as LanceHybridQuery
|
||||
from ._lancedb import PyQueryRequest
|
||||
from ._lancedb import Query as LanceQuery
|
||||
from ._lancedb import TakeQuery as LanceTakeQuery
|
||||
from ._lancedb import VectorQuery as LanceVectorQuery
|
||||
from ._lancedb import TakeQuery as LanceTakeQuery
|
||||
from ._lancedb import PyQueryRequest
|
||||
from .common import VEC
|
||||
from .pydantic import LanceModel
|
||||
from .table import Table
|
||||
@@ -719,7 +718,6 @@ class LanceQueryBuilder(ABC):
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
*,
|
||||
timeout: Optional[timedelta] = None,
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
@@ -737,12 +735,9 @@ class LanceQueryBuilder(ABC):
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
return tbl.to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
@@ -2357,7 +2352,6 @@ class AsyncQueryBase(object):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and collect the results into a pandas DataFrame.
|
||||
@@ -2390,13 +2384,10 @@ class AsyncQueryBase(object):
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
return (
|
||||
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
|
||||
).to_pandas(**kwargs)
|
||||
).to_pandas()
|
||||
|
||||
async def to_polars(
|
||||
self,
|
||||
@@ -3349,18 +3340,16 @@ class BaseQueryBuilder(object):
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
"""
|
||||
async_reader = LOOP.run(
|
||||
self._inner.to_batches(max_batch_length=max_batch_length, timeout=timeout)
|
||||
)
|
||||
async_iter = LOOP.run(self._inner.execute(max_batch_length, timeout))
|
||||
|
||||
def iter_sync():
|
||||
try:
|
||||
while True:
|
||||
yield LOOP.run(async_reader.__anext__())
|
||||
yield LOOP.run(async_iter.__anext__())
|
||||
except StopAsyncIteration:
|
||||
return
|
||||
|
||||
return pa.RecordBatchReader.from_batches(async_reader.schema, iter_sync())
|
||||
return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())
|
||||
|
||||
def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
|
||||
"""
|
||||
@@ -3400,7 +3389,6 @@ class BaseQueryBuilder(object):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and collect the results into a pandas DataFrame.
|
||||
@@ -3433,11 +3421,8 @@ class BaseQueryBuilder(object):
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
return LOOP.run(self._inner.to_pandas(flatten, timeout, **kwargs))
|
||||
return LOOP.run(self._inner.to_pandas(flatten, timeout))
|
||||
|
||||
def to_polars(
|
||||
self,
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import List, Optional
|
||||
from lancedb import __version__
|
||||
|
||||
from .header import HeaderProvider
|
||||
from .oauth import OAuthConfig, OAuthFlowType
|
||||
|
||||
__all__ = [
|
||||
"TimeoutConfig",
|
||||
@@ -16,6 +17,8 @@ __all__ = [
|
||||
"TlsConfig",
|
||||
"ClientConfig",
|
||||
"HeaderProvider",
|
||||
"OAuthConfig",
|
||||
"OAuthFlowType",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import sys
|
||||
@@ -18,7 +17,7 @@ else:
|
||||
|
||||
# Remove this import to fix circular dependency
|
||||
# from lancedb import connect_async
|
||||
from lancedb.remote import ClientConfig, RetryConfig, TimeoutConfig, TlsConfig
|
||||
from lancedb.remote import ClientConfig
|
||||
import pyarrow as pa
|
||||
|
||||
from ..common import DATA
|
||||
@@ -37,64 +36,6 @@ from ..table import Table
|
||||
from ..util import validate_table_name
|
||||
|
||||
|
||||
def _duration_seconds(value: Optional[timedelta]) -> Optional[float]:
|
||||
return value.total_seconds() if value is not None else None
|
||||
|
||||
|
||||
def _timeout_config_to_dict(
|
||||
config: Optional[TimeoutConfig],
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if config is None:
|
||||
return None
|
||||
return {
|
||||
"timeout": _duration_seconds(config.timeout),
|
||||
"connect_timeout": _duration_seconds(config.connect_timeout),
|
||||
"read_timeout": _duration_seconds(config.read_timeout),
|
||||
"pool_idle_timeout": _duration_seconds(config.pool_idle_timeout),
|
||||
}
|
||||
|
||||
|
||||
def _retry_config_to_dict(config: RetryConfig) -> dict[str, Any]:
|
||||
return {
|
||||
"retries": config.retries,
|
||||
"connect_retries": config.connect_retries,
|
||||
"read_retries": config.read_retries,
|
||||
"backoff_factor": config.backoff_factor,
|
||||
"backoff_jitter": config.backoff_jitter,
|
||||
"statuses": config.statuses,
|
||||
}
|
||||
|
||||
|
||||
def _tls_config_to_dict(config: Optional[TlsConfig]) -> Optional[dict[str, Any]]:
|
||||
if config is None:
|
||||
return None
|
||||
return {
|
||||
"cert_file": config.cert_file,
|
||||
"key_file": config.key_file,
|
||||
"ssl_ca_cert": config.ssl_ca_cert,
|
||||
"assert_hostname": config.assert_hostname,
|
||||
}
|
||||
|
||||
|
||||
def _client_config_to_dict(config: ClientConfig) -> dict[str, Any]:
|
||||
if config.header_provider is not None:
|
||||
raise ValueError(
|
||||
"Cannot serialize a remote connection with a header_provider. "
|
||||
"Use static api_key/extra_headers or provide a worker-side "
|
||||
"connection factory instead."
|
||||
)
|
||||
return {
|
||||
"user_agent": config.user_agent,
|
||||
"retry_config": _retry_config_to_dict(config.retry_config),
|
||||
"timeout_config": _timeout_config_to_dict(config.timeout_config),
|
||||
"extra_headers": config.extra_headers,
|
||||
"id_delimiter": config.id_delimiter,
|
||||
"tls_config": _tls_config_to_dict(config.tls_config),
|
||||
"header_provider": None,
|
||||
"user_id": config.user_id,
|
||||
}
|
||||
|
||||
|
||||
class RemoteDBConnection(DBConnection):
|
||||
"""A connection to a remote LanceDB database."""
|
||||
|
||||
@@ -109,7 +50,6 @@ class RemoteDBConnection(DBConnection):
|
||||
connection_timeout: Optional[float] = None,
|
||||
read_timeout: Optional[float] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
if isinstance(client_config, dict):
|
||||
@@ -148,11 +88,6 @@ class RemoteDBConnection(DBConnection):
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self.db_url = db_url
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.host_override = host_override
|
||||
self.storage_options = storage_options
|
||||
self.db_name = parsed.netloc
|
||||
|
||||
self.client_config = client_config
|
||||
@@ -168,27 +103,12 @@ class RemoteDBConnection(DBConnection):
|
||||
host_override=host_override,
|
||||
client_config=client_config,
|
||||
storage_options=storage_options,
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
|
||||
@override
|
||||
def serialize(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"connection_type": "remote",
|
||||
"db_url": self.db_url,
|
||||
"api_key": self.api_key,
|
||||
"region": self.region,
|
||||
"host_override": self.host_override,
|
||||
"client_config": _client_config_to_dict(self.client_config),
|
||||
"storage_options": self.storage_options,
|
||||
}
|
||||
)
|
||||
|
||||
@override
|
||||
def list_namespaces(
|
||||
self,
|
||||
@@ -409,12 +329,7 @@ class RemoteDBConnection(DBConnection):
|
||||
)
|
||||
|
||||
table = LOOP.run(self._conn.open_table(name, namespace_path=namespace_path))
|
||||
return RemoteTable(
|
||||
table,
|
||||
self.db_name,
|
||||
connection_state=self.serialize,
|
||||
namespace_path=namespace_path,
|
||||
)
|
||||
return RemoteTable(table, self.db_name)
|
||||
|
||||
def clone_table(
|
||||
self,
|
||||
@@ -463,12 +378,7 @@ class RemoteDBConnection(DBConnection):
|
||||
is_shallow=is_shallow,
|
||||
)
|
||||
)
|
||||
return RemoteTable(
|
||||
table,
|
||||
self.db_name,
|
||||
connection_state=self.serialize,
|
||||
namespace_path=target_namespace_path,
|
||||
)
|
||||
return RemoteTable(table, self.db_name)
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
@@ -613,12 +523,7 @@ class RemoteDBConnection(DBConnection):
|
||||
fill_value=fill_value,
|
||||
)
|
||||
)
|
||||
return RemoteTable(
|
||||
table,
|
||||
self.db_name,
|
||||
connection_state=self.serialize,
|
||||
namespace_path=namespace_path,
|
||||
)
|
||||
return RemoteTable(table, self.db_name)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
|
||||
|
||||
90
python/python/lancedb/remote/oauth.py
Normal file
90
python/python/lancedb/remote/oauth.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class OAuthFlowType(str, Enum):
|
||||
"""OAuth authentication flow types."""
|
||||
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
"""Client Credentials grant (service-to-service / M2M)."""
|
||||
|
||||
AUTHORIZATION_CODE_PKCE = "authorization_code_pkce"
|
||||
"""Authorization Code with PKCE (interactive browser-based auth)."""
|
||||
|
||||
DEVICE_CODE = "device_code"
|
||||
"""Device Code grant (CLI / headless environments)."""
|
||||
|
||||
AZURE_MANAGED_IDENTITY = "azure_managed_identity"
|
||||
"""Azure Managed Identity via IMDS."""
|
||||
|
||||
WORKLOAD_IDENTITY = "workload_identity"
|
||||
"""Workload Identity Federation (K8s, GitHub Actions)."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthConfig:
|
||||
"""OAuth configuration for LanceDB authentication.
|
||||
|
||||
All token acquisition and refresh is handled in the Rust layer.
|
||||
This config is passed through to Rust via PyO3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
issuer_url : str
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
|
||||
client_id : str
|
||||
Application / Client ID.
|
||||
scopes : List[str]
|
||||
OAuth scopes to request.
|
||||
For Azure: ``["api://{app_id}/.default"]``
|
||||
flow : OAuthFlowType
|
||||
Authentication flow to use. Default: CLIENT_CREDENTIALS.
|
||||
client_secret : Optional[str]
|
||||
Client secret (required for CLIENT_CREDENTIALS).
|
||||
redirect_uri : Optional[str]
|
||||
Redirect URI for AUTHORIZATION_CODE_PKCE flow.
|
||||
callback_port : Optional[int]
|
||||
Port for local HTTP callback server (AUTHORIZATION_CODE_PKCE, default: 8400).
|
||||
managed_identity_client_id : Optional[str]
|
||||
Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
|
||||
token_file : Optional[str]
|
||||
Path to federated token file (WORKLOAD_IDENTITY).
|
||||
refresh_buffer_secs : Optional[int]
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
|
||||
Examples
|
||||
--------
|
||||
Client Credentials (service-to-service):
|
||||
|
||||
>>> config = OAuthConfig(
|
||||
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
... client_id="app-id",
|
||||
... client_secret="secret",
|
||||
... scopes=["api://lancedb-api/.default"],
|
||||
... )
|
||||
|
||||
Azure Managed Identity:
|
||||
|
||||
>>> config = OAuthConfig(
|
||||
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
... client_id="app-id",
|
||||
... scopes=["api://lancedb-api/.default"],
|
||||
... flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
|
||||
... )
|
||||
"""
|
||||
|
||||
issuer_url: str
|
||||
client_id: str
|
||||
scopes: List[str]
|
||||
flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
|
||||
client_secret: Optional[str] = None
|
||||
redirect_uri: Optional[str] = None
|
||||
callback_port: Optional[int] = None
|
||||
managed_identity_client_id: Optional[str] = None
|
||||
token_file: Optional[str] = None
|
||||
refresh_buffer_secs: Optional[int] = None
|
||||
@@ -2,30 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from datetime import timedelta
|
||||
import deprecation
|
||||
import logging
|
||||
from functools import cached_property
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
Literal,
|
||||
overload,
|
||||
)
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal
|
||||
import warnings
|
||||
|
||||
from lancedb import __version__
|
||||
|
||||
from lancedb._lancedb import (
|
||||
AddColumnsResult,
|
||||
AddResult,
|
||||
AlterColumnsResult,
|
||||
UpdateFieldMetadataResult,
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
IndexConfig,
|
||||
@@ -47,7 +32,6 @@ from lancedb.index import (
|
||||
LabelList,
|
||||
)
|
||||
from lancedb.remote.db import LOOP
|
||||
from lancedb.table import IndexConfigType, KNOWN_METRICS
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
@@ -56,7 +40,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.table import _normalize_progress
|
||||
|
||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
|
||||
from ..table import AsyncTable, BlobMode, IndexStatistics, Query, Table, Tags
|
||||
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
|
||||
from ..types import BaseTokenizerType
|
||||
|
||||
|
||||
@@ -65,80 +49,14 @@ class RemoteTable(Table):
|
||||
self,
|
||||
table: AsyncTable,
|
||||
db_name: str,
|
||||
*,
|
||||
connection_state: Optional[Union[str, Callable[[], str]]] = None,
|
||||
namespace_path: Optional[List[str]] = None,
|
||||
):
|
||||
self._table_handle = table
|
||||
self._name = table.name
|
||||
self._table = table
|
||||
self.db_name = db_name
|
||||
self._connection_state = connection_state
|
||||
self._namespace_path = list(namespace_path or [])
|
||||
self._checkout_version: Optional[int] = None
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _serialized_connection_state(self) -> str:
|
||||
if self._connection_state is None:
|
||||
raise RuntimeError(
|
||||
"Cannot reopen this remote table because it does not carry "
|
||||
"serialized connection state"
|
||||
)
|
||||
if callable(self._connection_state):
|
||||
self._connection_state = self._connection_state()
|
||||
return self._connection_state
|
||||
|
||||
@property
|
||||
def _table(self) -> AsyncTable:
|
||||
self._ensure_open()
|
||||
assert self._table_handle is not None
|
||||
return self._table_handle
|
||||
|
||||
@_table.setter
|
||||
def _table(self, table: AsyncTable) -> None:
|
||||
self._table_handle = table
|
||||
self._name = table.name
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _ensure_open(self) -> None:
|
||||
pid = os.getpid()
|
||||
if self._table_handle is not None and self._pid == pid:
|
||||
return
|
||||
|
||||
# Pickle clears the handle; fork inherits a handle created in the
|
||||
# parent process. In both cases reopen before touching the Rust client.
|
||||
from lancedb import deserialize_conn
|
||||
|
||||
db = deserialize_conn(self._serialized_connection_state(), for_worker=True)
|
||||
table = db.open_table(self._name, namespace_path=self._namespace_path)
|
||||
if self._checkout_version is not None:
|
||||
table.checkout(self._checkout_version)
|
||||
|
||||
self._table_handle = table._table
|
||||
self.db_name = table.db_name
|
||||
self._pid = pid
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
return {
|
||||
"connection_state": self._serialized_connection_state(),
|
||||
"db_name": self.db_name,
|
||||
"name": self.name,
|
||||
"namespace_path": self._namespace_path,
|
||||
"checkout_version": self._checkout_version,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self._table_handle = None
|
||||
self._name = state["name"]
|
||||
self.db_name = state["db_name"]
|
||||
self._connection_state = state["connection_state"]
|
||||
self._namespace_path = state["namespace_path"]
|
||||
self._checkout_version = state["checkout_version"]
|
||||
self._pid = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the table"""
|
||||
return self._name
|
||||
return self._table.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self.db_name}.{self.name})"
|
||||
@@ -183,24 +101,18 @@ class RemoteTable(Table):
|
||||
"""to_arrow() is not yet supported on LanceDB cloud."""
|
||||
raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")
|
||||
|
||||
def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs):
|
||||
def to_pandas(self):
|
||||
"""to_pandas() is not yet supported on LanceDB cloud."""
|
||||
raise NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
||||
|
||||
def checkout(self, version: Union[int, str]):
|
||||
result = LOOP.run(self._table.checkout(version))
|
||||
self._checkout_version = self.version
|
||||
return result
|
||||
return LOOP.run(self._table.checkout(version))
|
||||
|
||||
def checkout_latest(self):
|
||||
result = LOOP.run(self._table.checkout_latest())
|
||||
self._checkout_version = None
|
||||
return result
|
||||
return LOOP.run(self._table.checkout_latest())
|
||||
|
||||
def restore(self, version: Optional[Union[int, str]] = None):
|
||||
result = LOOP.run(self._table.restore(version))
|
||||
self._checkout_version = None
|
||||
return result
|
||||
return LOOP.run(self._table.restore(version))
|
||||
|
||||
def list_indices(self) -> Iterable[IndexConfig]:
|
||||
"""List all the indices on the table"""
|
||||
@@ -210,11 +122,6 @@ class RemoteTable(Table):
|
||||
"""List all the stats of a specified index"""
|
||||
return LOOP.run(self._table.index_stats(index_uuid))
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.",
|
||||
)
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -224,12 +131,7 @@ class RemoteTable(Table):
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Creates a scalar index.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead.
|
||||
Example: ``table.create_index("column", config=BTree())``
|
||||
|
||||
"""Creates a scalar index
|
||||
Parameters
|
||||
----------
|
||||
column : str
|
||||
@@ -260,11 +162,6 @@ class RemoteTable(Table):
|
||||
)
|
||||
)
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=FTS() instead.",
|
||||
)
|
||||
def create_fts_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -285,12 +182,6 @@ class RemoteTable(Table):
|
||||
prefix_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a full-text search index on a column.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with an FTS config instead.
|
||||
Example: ``table.create_index("text_column", config=FTS())``
|
||||
"""
|
||||
config = FTS(
|
||||
with_position=with_position,
|
||||
base_tokenizer=base_tokenizer,
|
||||
@@ -314,43 +205,9 @@ class RemoteTable(Table):
|
||||
)
|
||||
)
|
||||
|
||||
# New unified API overload
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
column: str,
|
||||
/,
|
||||
*,
|
||||
config: IndexConfigType,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
# Legacy API overload (deprecated)
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
|
||||
vector_column_name: str = ...,
|
||||
index_cache_size: Optional[int] = ...,
|
||||
num_partitions: Optional[int] = ...,
|
||||
num_sub_vectors: Optional[int] = ...,
|
||||
replace: Optional[bool] = ...,
|
||||
accelerator: Optional[str] = ...,
|
||||
index_type: Literal[
|
||||
"VECTOR", "IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||
] = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
*,
|
||||
num_bits: int = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric: str = "l2",
|
||||
metric="l2",
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
index_cache_size: Optional[int] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
@@ -361,113 +218,89 @@ class RemoteTable(Table):
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
*,
|
||||
num_bits: int = 8,
|
||||
config: Optional[IndexConfigType] = None,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Create an index on a column.
|
||||
"""Create an index on the table.
|
||||
|
||||
This method supports both the new unified API and the legacy API
|
||||
for backwards compatibility. The new API takes the column name as the
|
||||
first positional argument and an index configuration object via
|
||||
``config``; the legacy API takes the distance metric as the first
|
||||
argument plus separate ``vector_column_name`` / ``num_partitions`` /
|
||||
etc. parameters, and emits a ``DeprecationWarning``.
|
||||
Parameters
|
||||
----------
|
||||
metric : str
|
||||
The metric to use for the index. Default is "l2".
|
||||
vector_column_name : str
|
||||
The name of the vector column. Default is "vector".
|
||||
|
||||
Examples
|
||||
--------
|
||||
New API (recommended):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "vector", config=IvfPq(distance_type="l2")
|
||||
>>> import lancedb
|
||||
>>> import uuid
|
||||
>>> from lancedb.schema import vector
|
||||
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
|
||||
... region="...") # doctest: +SKIP
|
||||
>>> table_name = uuid.uuid4().hex
|
||||
>>> schema = pa.schema(
|
||||
... [
|
||||
... pa.field("id", pa.uint32(), False),
|
||||
... pa.field("vector", vector(128), False),
|
||||
... pa.field("s", pa.string(), False),
|
||||
... ]
|
||||
... )
|
||||
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
|
||||
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
|
||||
|
||||
Legacy API (deprecated):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "l2", vector_column_name="vector"
|
||||
>>> table = db.create_table( # doctest: +SKIP
|
||||
... table_name, # doctest: +SKIP
|
||||
... schema=schema, # doctest: +SKIP
|
||||
... )
|
||||
>>> table.create_index("l2", "vector") # doctest: +SKIP
|
||||
"""
|
||||
# Detect whether this is a legacy API call
|
||||
is_legacy = self._is_legacy_create_index_call(
|
||||
metric,
|
||||
config,
|
||||
num_partitions,
|
||||
num_sub_vectors,
|
||||
vector_column_name,
|
||||
accelerator,
|
||||
index_cache_size,
|
||||
replace,
|
||||
)
|
||||
|
||||
if is_legacy:
|
||||
warnings.warn(
|
||||
"The create_index() API with metric/num_partitions parameters is "
|
||||
"deprecated and will be removed in a future version. "
|
||||
"Please migrate to the new unified API:\n"
|
||||
" # Old (deprecated):\n"
|
||||
" table.create_index('l2', vector_column_name='my_vector')\n"
|
||||
" # New (recommended):\n"
|
||||
" table.create_index('my_vector', config=IvfPq(distance_type='l2'))",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
if accelerator is not None:
|
||||
logging.warning(
|
||||
"GPU accelerator is not yet supported on LanceDB cloud."
|
||||
"If you have 100M+ vectors to index,"
|
||||
"please contact us at contact@lancedb.com"
|
||||
)
|
||||
if replace is not None:
|
||||
logging.warning(
|
||||
"replace is not supported on LanceDB cloud."
|
||||
"Existing indexes will always be replaced."
|
||||
)
|
||||
|
||||
column = vector_column_name
|
||||
|
||||
if accelerator is not None:
|
||||
logging.warning(
|
||||
"GPU accelerator is not yet supported on LanceDB cloud."
|
||||
"If you have 100M+ vectors to index,"
|
||||
"please contact us at contact@lancedb.com"
|
||||
)
|
||||
if replace is not None:
|
||||
logging.warning(
|
||||
"replace is not supported on LanceDB cloud."
|
||||
"Existing indexes will always be replaced."
|
||||
)
|
||||
|
||||
idx_type = index_type.upper()
|
||||
if idx_type == "VECTOR" or idx_type == "IVF_PQ":
|
||||
config = IvfPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif idx_type == "IVF_RQ":
|
||||
config = IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif idx_type == "IVF_SQ":
|
||||
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif idx_type == "IVF_HNSW_PQ":
|
||||
raise ValueError(
|
||||
"IVF_HNSW_PQ is not supported on LanceDB cloud."
|
||||
"Please use IVF_HNSW_SQ instead."
|
||||
)
|
||||
elif idx_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif idx_type == "IVF_HNSW_FLAT":
|
||||
config = HnswFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
elif idx_type == "IVF_FLAT":
|
||||
config = IvfFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {idx_type}. Valid options are"
|
||||
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
|
||||
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'"
|
||||
)
|
||||
index_type = index_type.upper()
|
||||
if index_type == "VECTOR" or index_type == "IVF_PQ":
|
||||
config = IvfPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif index_type == "IVF_RQ":
|
||||
config = IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif index_type == "IVF_SQ":
|
||||
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
raise ValueError(
|
||||
"IVF_HNSW_PQ is not supported on LanceDB cloud."
|
||||
"Please use IVF_HNSW_SQ instead."
|
||||
)
|
||||
elif index_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif index_type == "IVF_HNSW_FLAT":
|
||||
config = HnswFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
elif index_type == "IVF_FLAT":
|
||||
config = IvfFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
else:
|
||||
column = metric
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {index_type}. Valid options are"
|
||||
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
|
||||
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'"
|
||||
)
|
||||
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
column,
|
||||
vector_column_name,
|
||||
config=config,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
@@ -475,37 +308,6 @@ class RemoteTable(Table):
|
||||
)
|
||||
)
|
||||
|
||||
def _is_legacy_create_index_call(
|
||||
self,
|
||||
first_arg: str,
|
||||
config: Optional[IndexConfigType],
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
vector_column_name: str,
|
||||
accelerator: Optional[str],
|
||||
index_cache_size: Optional[int],
|
||||
replace: Optional[bool],
|
||||
) -> bool:
|
||||
"""Detect if this is a legacy create_index call."""
|
||||
if config is not None:
|
||||
return False
|
||||
if any(
|
||||
x is not None
|
||||
for x in (
|
||||
num_partitions,
|
||||
num_sub_vectors,
|
||||
accelerator,
|
||||
index_cache_size,
|
||||
replace,
|
||||
)
|
||||
):
|
||||
return True
|
||||
if vector_column_name != VECTOR_COLUMN_NAME:
|
||||
return True
|
||||
if first_arg.lower() in KNOWN_METRICS:
|
||||
return True
|
||||
return False
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
@@ -851,11 +653,6 @@ class RemoteTable(Table):
|
||||
) -> AlterColumnsResult:
|
||||
return LOOP.run(self._table.alter_columns(*alterations))
|
||||
|
||||
def update_field_metadata(
|
||||
self, *updates: dict[str, Any]
|
||||
) -> UpdateFieldMetadataResult:
|
||||
return LOOP.run(self._table.update_field_metadata(*updates))
|
||||
|
||||
def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult:
|
||||
return LOOP.run(self._table.drop_columns(columns))
|
||||
|
||||
@@ -871,10 +668,6 @@ class RemoteTable(Table):
|
||||
"""Not supported on LanceDB Cloud."""
|
||||
return LOOP.run(self._table.unset_lsm_write_spec())
|
||||
|
||||
def close_lsm_writers(self) -> None:
|
||||
"""No-op on LanceDB Cloud (no local shard writers)."""
|
||||
return LOOP.run(self._table.close_lsm_writers())
|
||||
|
||||
def drop_index(self, index_name: str):
|
||||
return LOOP.run(self._table.drop_index(index_name))
|
||||
|
||||
|
||||
@@ -102,15 +102,8 @@ class LinearCombinationReranker(Reranker):
|
||||
|
||||
combined_list = []
|
||||
for row_id, result in results.items():
|
||||
# Convert vector distance to a relevance score in [0, 1] where
|
||||
# higher is better. Missing vector entries are penalised with
|
||||
# `_invert_score(fill)` = 1 - fill (= 0.0 for the default fill=1).
|
||||
vector_score = self._invert_score(result.get("_distance", fill))
|
||||
# FTS scores (BM25) are already in a "higher = more relevant" space.
|
||||
# Missing FTS entries are penalised symmetrically: we use
|
||||
# `1 - fill` so that the same `fill` value drives both missing-vector
|
||||
# and missing-FTS penalties in the same direction.
|
||||
fts_score = result.get("_score", 1 - fill)
|
||||
fts_score = result.get("_score", fill)
|
||||
result["_relevance_score"] = self._combine_score(vector_score, fts_score)
|
||||
combined_list.append(result)
|
||||
|
||||
@@ -130,12 +123,8 @@ class LinearCombinationReranker(Reranker):
|
||||
return tbl
|
||||
|
||||
def _combine_score(self, vector_score, fts_score):
|
||||
# Both vector_score (inverted distance) and fts_score are in a
|
||||
# "higher = more relevant" space. A straight weighted average gives
|
||||
# higher _relevance_score to better matches, as expected.
|
||||
# Previously this returned `1 - (...)` which inverted the final
|
||||
# ranking so that the *least* relevant document ranked first.
|
||||
return self.weight * vector_score + (1 - self.weight) * fts_score
|
||||
# these scores represent distance
|
||||
return 1 - (self.weight * vector_score + (1 - self.weight) * fts_score)
|
||||
|
||||
def _invert_score(self, dist: float):
|
||||
# Invert the score between relevance and distance
|
||||
|
||||
@@ -87,8 +87,6 @@ from .util import (
|
||||
)
|
||||
from .index import lang_mapping
|
||||
|
||||
BlobMode = Literal["lazy", "bytes", "descriptions"]
|
||||
|
||||
_MODEL_BACKED_TOKENIZER_PREFIXES = ("jieba", "lindera")
|
||||
_MODEL_BACKED_TOKENIZER_ERRORS = (
|
||||
"unknown base tokenizer",
|
||||
@@ -154,7 +152,6 @@ if TYPE_CHECKING:
|
||||
AddColumnsResult,
|
||||
AddResult,
|
||||
AlterColumnsResult,
|
||||
UpdateFieldMetadataResult,
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
LsmWriteSpec,
|
||||
@@ -175,24 +172,6 @@ if TYPE_CHECKING:
|
||||
DistanceType,
|
||||
)
|
||||
|
||||
# Type alias for index configuration objects
|
||||
IndexConfigType = Union[
|
||||
IvfFlat,
|
||||
IvfPq,
|
||||
IvfSq,
|
||||
IvfRq,
|
||||
HnswFlat,
|
||||
HnswPq,
|
||||
HnswSq,
|
||||
BTree,
|
||||
Bitmap,
|
||||
LabelList,
|
||||
FTS,
|
||||
]
|
||||
|
||||
# Known distance metrics for legacy API detection
|
||||
KNOWN_METRICS = {"l2", "cosine", "dot", "hamming"}
|
||||
|
||||
|
||||
def _into_pyarrow_reader(
|
||||
data, schema: Optional[pa.Schema] = None
|
||||
@@ -758,15 +737,6 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def branches(self) -> "Branches":
|
||||
"""Branch management for the table.
|
||||
|
||||
Branches are isolated, writable lines of history forked from another
|
||||
branch (or version). Writes on a branch do not affect ``main``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of rows in this Table"""
|
||||
return self.count_rows(None)
|
||||
@@ -790,22 +760,14 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pandas.DataFrame":
|
||||
def to_pandas(self) -> "pandas.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how blob columns are returned for backends that support
|
||||
Lance blob-aware pandas conversion.
|
||||
**kwargs
|
||||
Forwarded to PyArrow / Lance pandas conversion.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
return self.to_arrow().to_pandas(**kwargs)
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self) -> pa.Table:
|
||||
@@ -835,49 +797,11 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# New unified API overload
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
column: str,
|
||||
/,
|
||||
*,
|
||||
config: IndexConfigType,
|
||||
replace: bool = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
# Legacy API overload (deprecated)
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
|
||||
num_partitions: Optional[int] = ...,
|
||||
num_sub_vectors: Optional[int] = ...,
|
||||
vector_column_name: str = ...,
|
||||
replace: bool = ...,
|
||||
accelerator: Optional[str] = ...,
|
||||
index_cache_size: Optional[int] = ...,
|
||||
*,
|
||||
index_type: VectorIndexType = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
num_bits: int = ...,
|
||||
max_iterations: int = ...,
|
||||
sample_rate: int = ...,
|
||||
m: int = ...,
|
||||
ef_construction: int = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
target_partition_size: Optional[int] = ...,
|
||||
) -> None: ...
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric: DistanceType = "l2",
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
metric="l2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
@@ -890,53 +814,46 @@ class Table(ABC):
|
||||
sample_rate: int = 256,
|
||||
m: int = 20,
|
||||
ef_construction: int = 300,
|
||||
config: Optional[IndexConfigType] = None,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
target_partition_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on a column.
|
||||
|
||||
This method supports both the new unified API and the legacy API
|
||||
for backwards compatibility. The new API takes the column name as the
|
||||
first positional argument and an index configuration object via
|
||||
``config``; the legacy API takes the distance metric as the first
|
||||
argument plus separate ``vector_column_name`` / ``num_partitions`` /
|
||||
etc. parameters, and emits a ``DeprecationWarning``.
|
||||
"""Create an index on the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : str
|
||||
For new API: the column name to index.
|
||||
For legacy API: the distance metric ("l2", "cosine", "dot", "hamming").
|
||||
config : IndexConfigType, optional
|
||||
The index configuration object. If provided, uses the new unified API.
|
||||
Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq,
|
||||
BTree, Bitmap, LabelList, FTS.
|
||||
replace : bool, default True
|
||||
Whether to replace an existing index on this column.
|
||||
wait_timeout : timedelta, optional
|
||||
Timeout to wait for async indexing to complete.
|
||||
name : str, optional
|
||||
Custom name for the index.
|
||||
train : bool, default True
|
||||
Whether to train the index with existing data.
|
||||
metric: str, default "l2"
|
||||
The distance metric to use when creating the index.
|
||||
Valid values are "l2", "cosine", "dot", or "hamming".
|
||||
l2 is euclidean distance.
|
||||
Hamming is available only for binary vectors.
|
||||
num_partitions: int, default 256
|
||||
The number of IVF partitions to use when creating the index.
|
||||
Default is 256.
|
||||
num_sub_vectors: int, default 96
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
vector_column_name: str, default "vector"
|
||||
The vector column name to create the index.
|
||||
replace: bool, default True
|
||||
- If True, replace the existing index if it exists.
|
||||
|
||||
Examples
|
||||
--------
|
||||
New API (recommended):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "vector", config=IvfPq(distance_type="l2")
|
||||
... )
|
||||
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
|
||||
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
|
||||
|
||||
Legacy API (deprecated):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "l2", vector_column_name="vector"
|
||||
... )
|
||||
- If False, raise an error if duplicate index exists.
|
||||
accelerator: str, default None
|
||||
If set, use the given accelerator to create the index.
|
||||
Only support "cuda" for now.
|
||||
index_cache_size : int, optional
|
||||
The size of the index cache in number of entries. Default value is 256.
|
||||
num_bits: int
|
||||
The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
|
||||
Only 4 and 8 are supported.
|
||||
wait_timeout: timedelta, optional
|
||||
The timeout to wait if indexing is asynchronous.
|
||||
name: str, optional
|
||||
The name of the index. If not provided, a default name will be generated.
|
||||
train: bool, default True
|
||||
Whether to train the index with existing data. Vector indices always train
|
||||
with existing data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1261,7 +1178,7 @@ class Table(ABC):
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> res
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3)
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
@@ -1809,29 +1726,6 @@ class Table(ABC):
|
||||
version: the new version number of the table after the alteration.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_field_metadata(
|
||||
self, *updates: dict[str, Any]
|
||||
) -> UpdateFieldMetadataResult:
|
||||
"""
|
||||
Update per-field (column) metadata.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
updates : dict
|
||||
One or more dicts, each with:
|
||||
- "path": str — dot-path to the field (e.g. "embedding" or "a.b.c").
|
||||
- "metadata": dict[str, str | None] — keys to set; a value of ``None``
|
||||
deletes that key.
|
||||
- "replace": bool, optional — replace the field's whole metadata map
|
||||
instead of merging (default False).
|
||||
|
||||
Returns
|
||||
-------
|
||||
UpdateFieldMetadataResult
|
||||
version: the new table version after the update.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult:
|
||||
"""
|
||||
@@ -2176,15 +2070,6 @@ class LanceTable(Table):
|
||||
"""
|
||||
return Tags(self._table)
|
||||
|
||||
@property
|
||||
def branches(self) -> "Branches":
|
||||
"""Branch management for the table.
|
||||
|
||||
``create``/``checkout`` return a new table handle scoped to the branch;
|
||||
writes on it do not affect ``main``.
|
||||
"""
|
||||
return Branches(self._table)
|
||||
|
||||
def checkout(self, version: Union[int, str]):
|
||||
"""Checkout a version of the table. This is an in-place operation.
|
||||
|
||||
@@ -2283,7 +2168,7 @@ class LanceTable(Table):
|
||||
return LOOP.run(self._table.count_rows(filter))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
val = f"{self.__class__.__name__}(name={self.name!r}"
|
||||
val = f"{self.__class__.__name__}(name={self.name!r}, version={self.version}"
|
||||
if self._conn.read_consistency_interval is not None:
|
||||
val += ", read_consistency_interval={!r}".format(
|
||||
self._conn.read_consistency_interval
|
||||
@@ -2298,27 +2183,14 @@ class LanceTable(Table):
|
||||
"""Return the first n rows of the table."""
|
||||
return LOOP.run(self._table.head(n))
|
||||
|
||||
def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pd.DataFrame":
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how Lance blob columns are returned.
|
||||
**kwargs
|
||||
Forwarded to Lance pandas conversion.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
if blob_mode == "lazy" and (
|
||||
self._namespace_client is not None
|
||||
or get_uri_scheme(self._dataset_path) == "memory"
|
||||
):
|
||||
return self.to_arrow().to_pandas(**kwargs)
|
||||
|
||||
return self.to_lance().to_pandas(blob_mode=blob_mode, **kwargs)
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as a pyarrow Table.
|
||||
@@ -2355,51 +2227,11 @@ class LanceTable(Table):
|
||||
dataset, allow_pyarrow_filter=False, batch_size=batch_size
|
||||
)
|
||||
|
||||
# New unified API overload
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
column: str,
|
||||
/,
|
||||
*,
|
||||
config: IndexConfigType,
|
||||
replace: bool = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
# Legacy API overload (deprecated)
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
|
||||
num_partitions: Optional[int] = ...,
|
||||
num_sub_vectors: Optional[int] = ...,
|
||||
vector_column_name: str = ...,
|
||||
replace: bool = ...,
|
||||
accelerator: Optional[str] = ...,
|
||||
index_cache_size: Optional[int] = ...,
|
||||
num_bits: int = ...,
|
||||
index_type: Literal[
|
||||
"IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||
] = ...,
|
||||
max_iterations: int = ...,
|
||||
sample_rate: int = ...,
|
||||
m: int = ...,
|
||||
ef_construction: int = ...,
|
||||
*,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
target_partition_size: Optional[int] = ...,
|
||||
) -> None: ...
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric: str = "l2",
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
metric: DistanceType = "l2",
|
||||
num_partitions=None,
|
||||
num_sub_vectors=None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
@@ -2419,232 +2251,47 @@ class LanceTable(Table):
|
||||
m: int = 20,
|
||||
ef_construction: int = 300,
|
||||
*,
|
||||
config: Optional[IndexConfigType] = None,
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
target_partition_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on a column.
|
||||
|
||||
This method supports both the new unified API and the legacy API
|
||||
for backwards compatibility. The new API takes the column name as the
|
||||
first positional argument and an index configuration object via
|
||||
``config``; the legacy API takes the distance metric as the first
|
||||
argument plus separate ``vector_column_name`` / ``num_partitions`` /
|
||||
etc. parameters, and emits a ``DeprecationWarning``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : str
|
||||
For new API: the column name to index.
|
||||
For legacy API: the distance metric ("l2", "cosine", "dot", "hamming").
|
||||
config : IndexConfigType, optional
|
||||
The index configuration object. If provided, uses the new unified API.
|
||||
Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq,
|
||||
BTree, Bitmap, LabelList, FTS.
|
||||
replace : bool, default True
|
||||
Whether to replace an existing index on this column.
|
||||
wait_timeout : timedelta, optional
|
||||
Timeout to wait for async indexing to complete.
|
||||
name : str, optional
|
||||
Custom name for the index.
|
||||
train : bool, default True
|
||||
Whether to train the index with existing data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
New API (recommended):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "vector", config=IvfPq(distance_type="l2")
|
||||
... )
|
||||
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
|
||||
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
|
||||
|
||||
Legacy API (deprecated):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "l2", vector_column_name="vector"
|
||||
... )
|
||||
"""
|
||||
# Detect whether this is a legacy API call
|
||||
is_legacy = self._is_legacy_create_index_call(
|
||||
metric,
|
||||
config,
|
||||
num_partitions,
|
||||
num_sub_vectors,
|
||||
vector_column_name,
|
||||
accelerator,
|
||||
index_cache_size,
|
||||
)
|
||||
|
||||
if is_legacy:
|
||||
warnings.warn(
|
||||
"The create_index() API with metric/num_partitions parameters is "
|
||||
"deprecated and will be removed in a future version. "
|
||||
"Please migrate to the new unified API:\n"
|
||||
" # Old (deprecated):\n"
|
||||
" table.create_index('l2', vector_column_name='my_vector')\n"
|
||||
" # New (recommended):\n"
|
||||
" table.create_index('my_vector', config=IvfPq(distance_type='l2'))",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Legacy API: first arg is the distance metric
|
||||
column = vector_column_name
|
||||
|
||||
# Build config from legacy parameters
|
||||
config = self._build_vector_config_from_legacy_params(
|
||||
metric=metric,
|
||||
"""Create an index on the table."""
|
||||
if accelerator is not None:
|
||||
# accelerator is only supported through pylance.
|
||||
self.to_lance().create_index(
|
||||
column=vector_column_name,
|
||||
index_type=index_type,
|
||||
metric=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
replace=replace,
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
# Handle accelerator through pylance
|
||||
if accelerator is not None:
|
||||
self.to_lance().create_index(
|
||||
column=column,
|
||||
index_type=index_type,
|
||||
metric=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
replace=replace,
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
num_bits=num_bits,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
)
|
||||
self.checkout_latest()
|
||||
return
|
||||
else:
|
||||
# New API: metric is the column name
|
||||
column = metric
|
||||
|
||||
# Check if config has accelerator set and dispatch to pylance
|
||||
if config is not None and hasattr(config, "accelerator"):
|
||||
acc = getattr(config, "accelerator", None)
|
||||
if acc is not None:
|
||||
# Dispatch to pylance for GPU acceleration
|
||||
index_type_map = {
|
||||
"IvfFlat": "IVF_FLAT",
|
||||
"IvfSq": "IVF_SQ",
|
||||
"IvfPq": "IVF_PQ",
|
||||
"IvfRq": "IVF_RQ",
|
||||
"HnswPq": "IVF_HNSW_PQ",
|
||||
"HnswSq": "IVF_HNSW_SQ",
|
||||
}
|
||||
cfg_type = type(config).__name__
|
||||
lance_index_type = index_type_map.get(cfg_type, "IVF_PQ")
|
||||
|
||||
self.to_lance().create_index(
|
||||
column=column,
|
||||
index_type=lance_index_type,
|
||||
metric=getattr(config, "distance_type", "l2"),
|
||||
num_partitions=getattr(config, "num_partitions", None),
|
||||
num_sub_vectors=getattr(config, "num_sub_vectors", None),
|
||||
replace=replace,
|
||||
accelerator=acc,
|
||||
num_bits=getattr(config, "num_bits", 8),
|
||||
m=getattr(config, "m", 20),
|
||||
ef_construction=getattr(config, "ef_construction", 300),
|
||||
target_partition_size=getattr(
|
||||
config, "target_partition_size", None
|
||||
),
|
||||
)
|
||||
self.checkout_latest()
|
||||
return
|
||||
|
||||
return LOOP.run(
|
||||
self._table.create_index(
|
||||
column,
|
||||
replace=replace,
|
||||
config=config,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
)
|
||||
|
||||
def _is_legacy_create_index_call(
|
||||
self,
|
||||
first_arg: str,
|
||||
config: Optional[IndexConfigType],
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
vector_column_name: str,
|
||||
accelerator: Optional[str],
|
||||
index_cache_size: Optional[int],
|
||||
) -> bool:
|
||||
"""Detect if this is a legacy create_index call."""
|
||||
# If config is provided, it's definitely the new API
|
||||
if config is not None:
|
||||
return False
|
||||
|
||||
# If old-style parameters were explicitly set, it's legacy
|
||||
if any(
|
||||
x is not None
|
||||
for x in (num_partitions, num_sub_vectors, accelerator, index_cache_size)
|
||||
):
|
||||
return True
|
||||
|
||||
# If vector_column_name differs from default, it's legacy
|
||||
if vector_column_name != VECTOR_COLUMN_NAME:
|
||||
return True
|
||||
|
||||
# If first arg is a known metric, assume legacy
|
||||
if first_arg.lower() in KNOWN_METRICS:
|
||||
return True
|
||||
|
||||
# Otherwise assume new API
|
||||
return False
|
||||
|
||||
def _build_vector_config_from_legacy_params(
|
||||
self,
|
||||
metric: str,
|
||||
index_type: str,
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
num_bits: int,
|
||||
max_iterations: int,
|
||||
sample_rate: int,
|
||||
m: int,
|
||||
ef_construction: int,
|
||||
target_partition_size: Optional[int],
|
||||
accelerator: Optional[str],
|
||||
) -> IndexConfigType:
|
||||
"""Build an index config object from legacy parameters."""
|
||||
if index_type == "IVF_FLAT":
|
||||
return IvfFlat(
|
||||
self.checkout_latest()
|
||||
return
|
||||
elif index_type == "IVF_FLAT":
|
||||
config = IvfFlat(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_SQ":
|
||||
return IvfSq(
|
||||
config = IvfSq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_PQ":
|
||||
return IvfPq(
|
||||
config = IvfPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
@@ -2652,20 +2299,18 @@ class LanceTable(Table):
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_RQ":
|
||||
return IvfRq(
|
||||
config = IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
return HnswPq(
|
||||
config = HnswPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
@@ -2675,10 +2320,9 @@ class LanceTable(Table):
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_HNSW_SQ":
|
||||
return HnswSq(
|
||||
config = HnswSq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
@@ -2686,10 +2330,9 @@ class LanceTable(Table):
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_HNSW_FLAT":
|
||||
return HnswFlat(
|
||||
config = HnswFlat(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
@@ -2701,6 +2344,16 @@ class LanceTable(Table):
|
||||
else:
|
||||
raise ValueError(f"Unknown index type {index_type}")
|
||||
|
||||
return LOOP.run(
|
||||
self._table.create_index(
|
||||
vector_column_name,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
)
|
||||
|
||||
def drop_index(self, name: str) -> None:
|
||||
"""
|
||||
Drops an index from the table
|
||||
@@ -2800,11 +2453,6 @@ class LanceTable(Table):
|
||||
"""
|
||||
return LOOP.run(self._table.latest_storage_options())
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.",
|
||||
)
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -2813,12 +2461,6 @@ class LanceTable(Table):
|
||||
index_type: ScalarIndexType = "BTREE",
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a scalar index on a column.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead.
|
||||
Example: ``table.create_index("column", config=BTree())``
|
||||
"""
|
||||
if index_type == "BTREE":
|
||||
config = BTree()
|
||||
elif index_type == "BITMAP":
|
||||
@@ -2831,11 +2473,6 @@ class LanceTable(Table):
|
||||
self._table.create_index(column, replace=replace, config=config, name=name)
|
||||
)
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=FTS() instead.",
|
||||
)
|
||||
def create_fts_index(
|
||||
self,
|
||||
field_names: Union[str, List[str]],
|
||||
@@ -2859,12 +2496,6 @@ class LanceTable(Table):
|
||||
prefix_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a full-text search index on a column.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with an FTS config instead.
|
||||
Example: ``table.create_index("text_column", config=FTS())``
|
||||
"""
|
||||
self._ensure_no_legacy_fts_index()
|
||||
|
||||
if use_tantivy:
|
||||
@@ -2888,6 +2519,11 @@ class LanceTable(Table):
|
||||
"at a time. To search over multiple text fields, create a "
|
||||
"separate FTS index for each field."
|
||||
)
|
||||
if "." in field_names:
|
||||
raise ValueError(
|
||||
"Native FTS indexes can only be created on top-level fields. "
|
||||
f"Received nested field path: {field_names!r}."
|
||||
)
|
||||
|
||||
if tokenizer_name is None:
|
||||
tokenizer_configs = {
|
||||
@@ -3625,11 +3261,6 @@ class LanceTable(Table):
|
||||
) -> AlterColumnsResult:
|
||||
return LOOP.run(self._table.alter_columns(*alterations))
|
||||
|
||||
def update_field_metadata(
|
||||
self, *updates: dict[str, Any]
|
||||
) -> UpdateFieldMetadataResult:
|
||||
return LOOP.run(self._table.update_field_metadata(*updates))
|
||||
|
||||
def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult:
|
||||
return LOOP.run(self._table.drop_columns(columns))
|
||||
|
||||
@@ -3648,11 +3279,6 @@ class LanceTable(Table):
|
||||
[`AsyncTable.unset_lsm_write_spec`][lancedb.AsyncTable.unset_lsm_write_spec]."""
|
||||
return LOOP.run(self._table.unset_lsm_write_spec())
|
||||
|
||||
def close_lsm_writers(self) -> None:
|
||||
"""Close cached MemWAL shard writers. See
|
||||
[`AsyncTable.close_lsm_writers`][lancedb.AsyncTable.close_lsm_writers]."""
|
||||
return LOOP.run(self._table.close_lsm_writers())
|
||||
|
||||
def uses_v2_manifest_paths(self) -> bool:
|
||||
"""
|
||||
Check if the table is using the new v2 manifest paths.
|
||||
@@ -3684,18 +3310,10 @@ class LanceTable(Table):
|
||||
"""
|
||||
LOOP.run(self._table.migrate_v2_manifest_paths())
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.33.1",
|
||||
current_version=__version__,
|
||||
details="Use update_field_metadata() instead.",
|
||||
)
|
||||
def replace_field_metadata(self, field_name: str, new_metadata: Dict[str, str]):
|
||||
"""
|
||||
Replace the metadata of a field in the schema
|
||||
|
||||
.. deprecated:: 0.33.1
|
||||
Use :func:`update_field_metadata` instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
field_name: str
|
||||
@@ -4269,16 +3887,6 @@ class AsyncTable:
|
||||
"""
|
||||
await self._inner.unset_lsm_write_spec()
|
||||
|
||||
async def close_lsm_writers(self) -> None:
|
||||
"""Drain and close any cached MemWAL shard writers for this table.
|
||||
|
||||
When an LSM write spec is installed, `merge_insert` opens MemWAL shard
|
||||
writers and caches them for reuse across calls. This closes them,
|
||||
flushing pending data; writers reopen lazily on the next
|
||||
`merge_insert`. It is a no-op when no writers are cached.
|
||||
"""
|
||||
await self._inner.close_lsm_writers()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the table."""
|
||||
@@ -4337,39 +3945,14 @@ class AsyncTable:
|
||||
"""
|
||||
return AsyncQuery(self._inner.query())
|
||||
|
||||
async def _to_lance(self, **kwargs) -> lance.LanceDataset:
|
||||
try:
|
||||
import lance
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The lance library is required to use this function. "
|
||||
"Please install with `pip install pylance`."
|
||||
)
|
||||
|
||||
return lance.dataset(
|
||||
await self.uri(),
|
||||
version=await self.version(),
|
||||
storage_options=await self.latest_storage_options(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def to_pandas(self, blob_mode: BlobMode = "lazy", **kwargs) -> "pd.DataFrame":
|
||||
async def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how Lance blob columns are returned.
|
||||
**kwargs
|
||||
Forwarded to PyArrow / Lance pandas conversion.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
if blob_mode == "lazy":
|
||||
return (await self.to_arrow()).to_pandas(**kwargs)
|
||||
return (await self._to_lance()).to_pandas(blob_mode=blob_mode, **kwargs)
|
||||
return (await self.to_arrow()).to_pandas()
|
||||
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as a pyarrow Table.
|
||||
@@ -4729,7 +4312,7 @@ class AsyncTable:
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> res
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3)
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
@@ -5109,8 +4692,6 @@ class AsyncTable:
|
||||
when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition,
|
||||
timeout=merge._timeout,
|
||||
use_index=merge._use_index,
|
||||
use_lsm_write=merge._use_lsm_write,
|
||||
validate_single_shard=merge._validate_single_shard,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5289,13 +4870,6 @@ class AsyncTable:
|
||||
"""
|
||||
return await self._inner.alter_columns(alterations)
|
||||
|
||||
async def update_field_metadata(
|
||||
self, *updates: dict[str, Any]
|
||||
) -> UpdateFieldMetadataResult:
|
||||
"""Update per-field metadata. See
|
||||
[`Table.update_field_metadata`][lancedb.table.Table.update_field_metadata]."""
|
||||
return await self._inner.update_field_metadata(updates)
|
||||
|
||||
async def drop_columns(self, columns: Iterable[str]):
|
||||
"""
|
||||
Drop columns from the table.
|
||||
@@ -5460,15 +5034,6 @@ class AsyncTable:
|
||||
"""
|
||||
return AsyncTags(self._inner)
|
||||
|
||||
@property
|
||||
def branches(self) -> AsyncBranches:
|
||||
"""Branch management for the table.
|
||||
|
||||
Branches are isolated, writable lines of history forked from another
|
||||
branch (or version). Writes on a branch do not affect ``main``.
|
||||
"""
|
||||
return AsyncBranches(self._inner)
|
||||
|
||||
async def optimize(
|
||||
self,
|
||||
*,
|
||||
@@ -5589,20 +5154,12 @@ class AsyncTable:
|
||||
"""
|
||||
await self._inner.migrate_manifest_paths_v2()
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.33.1",
|
||||
current_version=__version__,
|
||||
details="Use update_field_metadata() instead.",
|
||||
)
|
||||
async def replace_field_metadata(
|
||||
self, field_name: str, new_metadata: dict[str, str]
|
||||
):
|
||||
"""
|
||||
Replace the metadata of a field in the schema
|
||||
|
||||
.. deprecated:: 0.33.1
|
||||
Use :func:`update_field_metadata` instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
field_name: str
|
||||
@@ -5804,50 +5361,6 @@ class Tags:
|
||||
LOOP.run(self._table.tags.update(tag, version))
|
||||
|
||||
|
||||
class Branches:
|
||||
"""
|
||||
Table branch manager.
|
||||
"""
|
||||
|
||||
def __init__(self, table):
|
||||
self._table = table
|
||||
|
||||
def list(self) -> Dict[str, Any]:
|
||||
"""List all branches, mapping name to branch metadata."""
|
||||
return LOOP.run(self._table.branches.list())
|
||||
|
||||
def create(
|
||||
self,
|
||||
name: str,
|
||||
from_ref: Optional[str] = None,
|
||||
from_version: Optional[int] = None,
|
||||
) -> "LanceTable":
|
||||
"""Create a branch and return a handle scoped to it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Name of the new branch.
|
||||
from_ref: str, optional
|
||||
Source branch to fork from. Defaults to ``main``.
|
||||
from_version: int, optional
|
||||
A specific version on ``from_ref`` to fork from. Defaults to latest.
|
||||
"""
|
||||
async_table = LOOP.run(
|
||||
self._table.branches.create(name, from_ref, from_version)
|
||||
)
|
||||
return LanceTable.from_inner(async_table._inner)
|
||||
|
||||
def checkout(self, name: str) -> "LanceTable":
|
||||
"""Check out an existing branch and return a handle scoped to it."""
|
||||
async_table = LOOP.run(self._table.branches.checkout(name))
|
||||
return LanceTable.from_inner(async_table._inner)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
"""Delete a branch."""
|
||||
LOOP.run(self._table.branches.delete(name))
|
||||
|
||||
|
||||
class AsyncTags:
|
||||
"""
|
||||
Async table tag manager.
|
||||
@@ -5915,47 +5428,3 @@ class AsyncTags:
|
||||
The new table version to tag.
|
||||
"""
|
||||
await self._table.tags.update(tag, version)
|
||||
|
||||
|
||||
class AsyncBranches:
|
||||
"""Async table branch manager."""
|
||||
|
||||
def __init__(self, table):
|
||||
self._table = table
|
||||
|
||||
async def list(self) -> Dict[str, Any]:
|
||||
"""List all branches, mapping name to branch metadata."""
|
||||
return await self._table.branches.list()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
name: str,
|
||||
from_ref: Optional[str] = None,
|
||||
from_version: Optional[int] = None,
|
||||
) -> "AsyncTable":
|
||||
"""Create a branch and return a handle scoped to it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Name of the new branch.
|
||||
from_ref: str, optional
|
||||
Source branch to fork from. Defaults to ``main``.
|
||||
from_version: int, optional
|
||||
A specific version on ``from_ref`` to fork from. Defaults to latest.
|
||||
"""
|
||||
# "main" and None are two spellings of the root branch in lance; normalize
|
||||
# so from_ref="main" behaves identically to the default.
|
||||
if from_ref == "main":
|
||||
from_ref = None
|
||||
inner = await self._table.branches.create(name, from_ref, from_version)
|
||||
return AsyncTable(inner)
|
||||
|
||||
async def checkout(self, name: str) -> "AsyncTable":
|
||||
"""Check out an existing branch and return a handle scoped to it."""
|
||||
inner = await self._table.branches.checkout(name)
|
||||
return AsyncTable(inner)
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
"""Delete a branch."""
|
||||
await self._table.branches.delete(name)
|
||||
|
||||
@@ -10,7 +10,7 @@ import pathlib
|
||||
import warnings
|
||||
from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple, Union, Optional, Any, List
|
||||
from typing import Tuple, Union, Optional, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
@@ -189,33 +189,7 @@ def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
|
||||
return tbl
|
||||
|
||||
|
||||
def _format_field_path(path: List[str]) -> str:
|
||||
def format_segment(segment: str) -> str:
|
||||
if all(char.isalnum() or char == "_" for char in segment):
|
||||
return segment
|
||||
return f"`{segment.replace('`', '``')}`"
|
||||
|
||||
return ".".join(format_segment(segment) for segment in path)
|
||||
|
||||
|
||||
def _iter_vector_columns(
|
||||
field: pa.Field, path: List[str], dim: Optional[int] = None
|
||||
) -> List[str]:
|
||||
field_path = [*path, field.name]
|
||||
if is_vector_column(field.type):
|
||||
vector_dim = infer_vector_column_dim(field.type)
|
||||
if dim is None or vector_dim == dim:
|
||||
return [_format_field_path(field_path)]
|
||||
return []
|
||||
if pa.types.is_struct(field.type):
|
||||
columns = []
|
||||
for idx in range(field.type.num_fields):
|
||||
columns.extend(_iter_vector_columns(field.type.field(idx), field_path, dim))
|
||||
return columns
|
||||
return []
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema, dim: Optional[int] = None) -> str:
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
|
||||
@@ -228,21 +202,26 @@ def inf_vector_column_query(schema: pa.Schema, dim: Optional[int] = None) -> str
|
||||
-------
|
||||
str: the vector column name.
|
||||
"""
|
||||
vector_col_names = []
|
||||
for field in schema:
|
||||
vector_col_names.extend(_iter_vector_columns(field, [], dim))
|
||||
if len(vector_col_names) > 1:
|
||||
raise ValueError(
|
||||
"Schema has more than one vector column. "
|
||||
"Please specify the vector column name "
|
||||
f"for vector search. Candidates: {vector_col_names}"
|
||||
)
|
||||
if len(vector_col_names) == 0:
|
||||
vector_col_name = ""
|
||||
vector_col_count = 0
|
||||
for field_name in schema.names:
|
||||
field = schema.field(field_name)
|
||||
if is_vector_column(field.type):
|
||||
vector_col_count += 1
|
||||
if vector_col_count > 1:
|
||||
raise ValueError(
|
||||
"Schema has more than one vector column. "
|
||||
"Please specify the vector column name "
|
||||
"for vector search"
|
||||
)
|
||||
elif vector_col_count == 1:
|
||||
vector_col_name = field_name
|
||||
if vector_col_count == 0:
|
||||
raise ValueError(
|
||||
"There is no vector column in the data. "
|
||||
"Please specify the vector column name for vector search"
|
||||
)
|
||||
return vector_col_names[0]
|
||||
return vector_col_name
|
||||
|
||||
|
||||
def is_vector_column(data_type: pa.DataType) -> bool:
|
||||
@@ -268,29 +247,6 @@ def is_vector_column(data_type: pa.DataType) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def infer_vector_column_dim(data_type: pa.DataType) -> Optional[int]:
|
||||
if pa.types.is_fixed_size_list(data_type):
|
||||
return data_type.list_size
|
||||
if pa.types.is_list(data_type):
|
||||
return infer_vector_column_dim(data_type.value_type)
|
||||
return None
|
||||
|
||||
|
||||
def _query_vector_dim(query: Optional[Any]) -> Optional[int]:
|
||||
if query is None:
|
||||
return None
|
||||
if isinstance(query, np.ndarray):
|
||||
if query.ndim == 0:
|
||||
return None
|
||||
return query.shape[-1]
|
||||
if isinstance(query, list) and query:
|
||||
first = query[0]
|
||||
if isinstance(first, (list, tuple, np.ndarray)):
|
||||
return len(first)
|
||||
return len(query)
|
||||
return None
|
||||
|
||||
|
||||
def infer_vector_column_name(
|
||||
schema: pa.Schema,
|
||||
query_type: str,
|
||||
@@ -306,9 +262,7 @@ def infer_vector_column_name(
|
||||
|
||||
if query is not None or query_type == "hybrid":
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(
|
||||
schema, dim=_query_vector_dim(query)
|
||||
)
|
||||
vector_column_name = inf_vector_column_query(schema)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ async def test_upsert_async(mem_db_async):
|
||||
await table.count_rows() # 3
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=1, num_deleted_rows=0, num_rows=2)
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# --8<-- [end:upsert_basic_async]
|
||||
assert await table.count_rows() == 3
|
||||
assert res.version == 2
|
||||
@@ -86,7 +86,7 @@ def test_insert_if_not_exists(mem_db):
|
||||
table.count_rows() # 3
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=0,
|
||||
# num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert table.count_rows() == 3
|
||||
assert res.version == 2
|
||||
@@ -116,7 +116,7 @@ async def test_insert_if_not_exists_async(mem_db_async):
|
||||
await table.count_rows() # 3
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=0,
|
||||
# num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert await table.count_rows() == 3
|
||||
assert res.version == 2
|
||||
@@ -150,7 +150,7 @@ def test_replace_range(mem_db):
|
||||
table.count_rows("doc_id = 1") # 1
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
|
||||
# num_inserted_rows=0, num_deleted_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert table.count_rows("doc_id = 1") == 1
|
||||
assert res.version == 2
|
||||
@@ -185,7 +185,7 @@ async def test_replace_range_async(mem_db_async):
|
||||
await table.count_rows("doc_id = 1") # 1
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
|
||||
# num_inserted_rows=0, num_deleted_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert await table.count_rows("doc_id = 1") == 1
|
||||
assert res.version == 2
|
||||
|
||||
@@ -6,7 +6,6 @@ import re
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
@@ -189,43 +188,6 @@ def test_table_names(tmp_db: lancedb.DBConnection):
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
def test_db_contains_and_len_include_all_table_name_pages(tmp_db: lancedb.DBConnection):
|
||||
for idx in range(20):
|
||||
tmp_db.create_table(f"table_{idx}", data=[{"id": idx}])
|
||||
|
||||
assert len(tmp_db) == 20
|
||||
for idx in range(20):
|
||||
assert f"table_{idx}" in tmp_db
|
||||
assert "does_not_exist" not in tmp_db
|
||||
|
||||
|
||||
def test_db_contains_stops_after_matching_table_page(
|
||||
tmp_db: lancedb.DBConnection, monkeypatch
|
||||
):
|
||||
calls = []
|
||||
pages = {
|
||||
None: SimpleNamespace(tables=["table_0", "table_1"], page_token="next"),
|
||||
"next": SimpleNamespace(tables=["table_2"], page_token=None),
|
||||
}
|
||||
|
||||
def list_tables(*, page_token=None, **_kwargs):
|
||||
calls.append(page_token)
|
||||
return pages[page_token]
|
||||
|
||||
monkeypatch.setattr(tmp_db, "list_tables", list_tables)
|
||||
|
||||
assert "table_1" in tmp_db
|
||||
assert calls == [None]
|
||||
|
||||
calls.clear()
|
||||
assert "table_2" in tmp_db
|
||||
assert calls == [None, "next"]
|
||||
|
||||
calls.clear()
|
||||
assert len(tmp_db) == 3
|
||||
assert calls == [None, "next"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_names_async(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -466,8 +428,7 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
|
||||
assert await tbl.uses_v2_manifest_paths()
|
||||
manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions"
|
||||
for manifest in os.listdir(manifests_dir):
|
||||
if manifest.endswith(".manifest"):
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
|
||||
# Start a table in V1 mode then migrate
|
||||
tbl = await db_no_v2_paths.create_table(
|
||||
@@ -477,15 +438,13 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
|
||||
assert not await tbl.uses_v2_manifest_paths()
|
||||
manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions"
|
||||
for manifest in os.listdir(manifests_dir):
|
||||
if manifest.endswith(".manifest"):
|
||||
assert re.match(r"\d\.manifest", manifest)
|
||||
assert re.match(r"\d\.manifest", manifest)
|
||||
|
||||
await tbl.migrate_manifest_paths_v2()
|
||||
assert await tbl.uses_v2_manifest_paths()
|
||||
|
||||
for manifest in os.listdir(manifests_dir):
|
||||
if manifest.endswith(".manifest"):
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -215,12 +215,11 @@ def test_reject_legacy_tantivy_index(table):
|
||||
|
||||
@pytest.mark.parametrize("with_position", [True, False])
|
||||
def test_create_inverted_index(table, with_position):
|
||||
with pytest.warns(DeprecationWarning, match="create_fts_index"):
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
with_position=with_position,
|
||||
name="custom_fts_index",
|
||||
)
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
with_position=with_position,
|
||||
name="custom_fts_index",
|
||||
)
|
||||
indices = table.list_indices()
|
||||
fts_indices = [i for i in indices if i.index_type == "FTS"]
|
||||
assert any(i.name == "custom_fts_index" for i in fts_indices)
|
||||
@@ -564,111 +563,8 @@ def test_create_index_multiple_columns(tmp_path, table):
|
||||
|
||||
|
||||
def test_nested_schema(tmp_path, table):
|
||||
table.create_fts_index("nested.text", with_position=True)
|
||||
indices = table.list_indices()
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "FTS"
|
||||
assert indices[0].columns == ["nested.text"]
|
||||
|
||||
results = (
|
||||
table.search("puppy", query_type="fts", fts_columns="nested.text")
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) > 0
|
||||
assert all("puppy" in row["nested"]["text"] for row in results)
|
||||
|
||||
results = table.search(MatchQuery("puppy", "nested.text")).limit(5).to_list()
|
||||
assert len(results) > 0
|
||||
assert all("puppy" in row["nested"]["text"] for row in results)
|
||||
|
||||
phrase_results = (
|
||||
table.search(PhraseQuery("puppy runs", "nested.text")).limit(5).to_list()
|
||||
)
|
||||
assert len(phrase_results) > 0
|
||||
assert all("puppy runs" in row["nested"]["text"] for row in phrase_results)
|
||||
|
||||
hybrid_results = (
|
||||
table.search(query_type="hybrid", fts_columns="nested.text")
|
||||
.vector([0 for _ in range(128)])
|
||||
.text("puppy")
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(hybrid_results) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_schema_async(async_table):
|
||||
await async_table.create_index("nested.text", config=FTS(with_position=True))
|
||||
indices = await async_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "FTS"
|
||||
assert indices[0].columns == ["nested.text"]
|
||||
|
||||
results = await (
|
||||
async_table.query()
|
||||
.nearest_to_text("puppy", columns="nested.text")
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) > 0
|
||||
assert all("puppy" in row["nested"]["text"] for row in results)
|
||||
|
||||
results = await (
|
||||
async_table.query()
|
||||
.nearest_to_text(MatchQuery("puppy", "nested.text"))
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(results) > 0
|
||||
assert all("puppy" in row["nested"]["text"] for row in results)
|
||||
|
||||
phrase_results = await (
|
||||
async_table.query()
|
||||
.nearest_to_text(PhraseQuery("puppy runs", "nested.text"))
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(phrase_results) > 0
|
||||
assert all("puppy runs" in row["nested"]["text"] for row in phrase_results)
|
||||
|
||||
hybrid_results = await (
|
||||
async_table.query()
|
||||
.nearest_to([0 for _ in range(128)])
|
||||
.nearest_to_text("puppy", columns="nested.text")
|
||||
.limit(5)
|
||||
.to_list()
|
||||
)
|
||||
assert len(hybrid_results) > 0
|
||||
|
||||
|
||||
def test_nested_schema_rejects_invalid_fts_fields(tmp_path):
|
||||
db = ldb.connect(tmp_path)
|
||||
data = pa.table(
|
||||
{
|
||||
"payload": pa.array(
|
||||
[
|
||||
{"text": "puppy runs", "count": 1},
|
||||
{"text": "car drives", "count": 2},
|
||||
]
|
||||
),
|
||||
"vector": pa.array(
|
||||
[[0.1, 0.1], [0.2, 0.2]],
|
||||
type=pa.list_(pa.float32(), list_size=2),
|
||||
),
|
||||
}
|
||||
)
|
||||
table = db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="FTS index cannot be created.*payload"):
|
||||
table.create_fts_index("payload")
|
||||
|
||||
with pytest.raises(ValueError, match="FTS index cannot be created.*count"):
|
||||
table.create_fts_index("payload.count")
|
||||
|
||||
with pytest.raises(ValueError, match="Field path `payload.missing` not found"):
|
||||
table.create_fts_index("payload.missing")
|
||||
with pytest.raises(ValueError, match="top-level fields"):
|
||||
table.create_fts_index("nested.text")
|
||||
|
||||
|
||||
def test_search_index_with_filter(table):
|
||||
|
||||
@@ -105,46 +105,6 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
||||
assert len(indices) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_nested_scalar_index_lists_canonical_paths(db_async):
|
||||
metadata_type = pa.struct(
|
||||
[
|
||||
pa.field("user_id", pa.int32()),
|
||||
pa.field("user.id", pa.int32()),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_arrays(
|
||||
[
|
||||
pa.array([1, 2, 3], type=pa.int32()),
|
||||
pa.array(
|
||||
[
|
||||
{"user_id": 10, "user.id": 100},
|
||||
{"user_id": 20, "user.id": 200},
|
||||
{"user_id": 30, "user.id": 300},
|
||||
],
|
||||
type=metadata_type,
|
||||
),
|
||||
],
|
||||
names=["user_id", "metadata"],
|
||||
)
|
||||
table = await db_async.create_table("nested_scalar_index", data)
|
||||
|
||||
await table.create_index("user_id", config=BTree(), name="top_user_id_idx")
|
||||
await table.create_index(
|
||||
"metadata.user_id", config=BTree(), name="nested_user_id_idx"
|
||||
)
|
||||
await table.create_index(
|
||||
"metadata.`user.id`", config=BTree(), name="escaped_user_id_idx"
|
||||
)
|
||||
|
||||
columns_by_name = {
|
||||
index.name: index.columns for index in await table.list_indices()
|
||||
}
|
||||
assert columns_by_name["top_user_id_idx"] == ["user_id"]
|
||||
assert columns_by_name["nested_user_id_idx"] == ["metadata.user_id"]
|
||||
assert columns_by_name["escaped_user_id_idx"] == ["metadata.`user.id`"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_fixed_size_binary_index(some_table: AsyncTable):
|
||||
await some_table.create_index("fsb", config=BTree())
|
||||
@@ -162,13 +122,12 @@ async def test_create_bitmap_index(some_table: AsyncTable):
|
||||
await some_table.create_index("data", config=Bitmap())
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 3
|
||||
# list_indices returns indices in alphabetical order by name
|
||||
assert indices[0].index_type == "Bitmap"
|
||||
assert indices[0].columns == ["data"]
|
||||
assert indices[0].columns == ["id"]
|
||||
assert indices[1].index_type == "Bitmap"
|
||||
assert indices[1].columns == ["id"]
|
||||
assert indices[1].columns == ["is_active"]
|
||||
assert indices[2].index_type == "Bitmap"
|
||||
assert indices[2].columns == ["is_active"]
|
||||
assert indices[2].columns == ["data"]
|
||||
|
||||
index_name = indices[0].name
|
||||
stats = await some_table.index_stats(index_name)
|
||||
|
||||
@@ -40,6 +40,16 @@ def _make_table(tmp_path):
|
||||
def test_set_lsm_write_spec_validates(tmp_path):
|
||||
_db, table = _make_table(tmp_path)
|
||||
|
||||
# No PK set yet.
|
||||
with pytest.raises(Exception, match="primary key"):
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
|
||||
|
||||
table.set_unenforced_primary_key("id")
|
||||
|
||||
# Column mismatch.
|
||||
with pytest.raises(Exception, match="match"):
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("v", 4))
|
||||
|
||||
# Out-of-range num_buckets.
|
||||
with pytest.raises(Exception, match="num_buckets"):
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 0))
|
||||
@@ -60,6 +70,7 @@ def test_unset_lsm_write_spec(tmp_path):
|
||||
table.unset_lsm_write_spec()
|
||||
|
||||
# Install a spec, then remove it; afterwards a fresh spec can be set.
|
||||
table.set_unenforced_primary_key("id")
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
|
||||
table.unset_lsm_write_spec()
|
||||
# A second unset errors — there is no spec left to remove.
|
||||
|
||||
@@ -1,196 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Tests for the MemWAL LSM ``merge_insert`` dispatch."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb._lancedb import LsmWriteSpec
|
||||
|
||||
SCHEMA = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64(), nullable=False),
|
||||
pa.field("value", pa.int64(), nullable=False),
|
||||
]
|
||||
)
|
||||
|
||||
REGION_SCHEMA = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64(), nullable=False),
|
||||
pa.field("region", pa.utf8(), nullable=False),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _reader(ids):
|
||||
batch = pa.RecordBatch.from_arrays(
|
||||
[
|
||||
pa.array(ids, type=pa.int64()),
|
||||
pa.array(list(range(len(ids))), type=pa.int64()),
|
||||
],
|
||||
schema=SCHEMA,
|
||||
)
|
||||
return pa.RecordBatchReader.from_batches(SCHEMA, [batch])
|
||||
|
||||
|
||||
def _region_reader(rows):
|
||||
batch = pa.RecordBatch.from_arrays(
|
||||
[
|
||||
pa.array([row[0] for row in rows], type=pa.int64()),
|
||||
pa.array([row[1] for row in rows], type=pa.utf8()),
|
||||
],
|
||||
schema=REGION_SCHEMA,
|
||||
)
|
||||
return pa.RecordBatchReader.from_batches(REGION_SCHEMA, [batch])
|
||||
|
||||
|
||||
def _bucket_table(tmp_path):
|
||||
"""A table with ``id`` as the primary key and a single-bucket LSM spec."""
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _reader([1, 2, 3]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
# num_buckets = 1: every row routes to the single bucket.
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1))
|
||||
return table
|
||||
|
||||
|
||||
def test_lsm_merge_insert_bucket(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
# Empty `on` defaults to the primary key.
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([3, 4, 5]))
|
||||
)
|
||||
# LSM path: rows go to the MemWAL, so only num_rows is populated.
|
||||
assert result.num_rows == 3
|
||||
assert result.version == 0
|
||||
assert result.num_inserted_rows == 0
|
||||
assert result.num_updated_rows == 0
|
||||
|
||||
|
||||
def test_lsm_merge_insert_unsharded(tmp_path):
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _reader([1, 2, 3]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
table.set_lsm_write_spec(LsmWriteSpec.unsharded())
|
||||
result = (
|
||||
table.merge_insert("id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([10, 11, 12, 13]))
|
||||
)
|
||||
assert result.num_rows == 4
|
||||
|
||||
|
||||
def test_lsm_merge_insert_identity(tmp_path):
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _region_reader([(1, "us"), (2, "us")]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
table.set_lsm_write_spec(LsmWriteSpec.identity("region"))
|
||||
# All rows share one identity value, so they route to one shard.
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_region_reader([(3, "us"), (4, "us")]))
|
||||
)
|
||||
assert result.num_rows == 2
|
||||
|
||||
|
||||
def test_lsm_merge_insert_use_lsm_write_false(tmp_path):
|
||||
table = _bucket_table(tmp_path) # rows id = 1, 2, 3
|
||||
# use_lsm_write(False) opts out: the standard path runs and commits.
|
||||
result = (
|
||||
table.merge_insert("id")
|
||||
.when_not_matched_insert_all()
|
||||
.use_lsm_write(False)
|
||||
.execute(_reader([3, 4, 5]))
|
||||
)
|
||||
assert result.num_inserted_rows == 2
|
||||
assert table.count_rows() == 5
|
||||
|
||||
|
||||
def test_lsm_merge_insert_validate_single_shard_off(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.validate_single_shard(False)
|
||||
.execute(_reader([6, 7, 8]))
|
||||
)
|
||||
assert result.num_rows == 3
|
||||
|
||||
|
||||
def test_lsm_merge_insert_use_lsm_write_true_requires_spec(tmp_path):
|
||||
# A table with a primary key but no LSM write spec installed.
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _reader([1, 2, 3]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
with pytest.raises(Exception, match="use_lsm_write"):
|
||||
(
|
||||
table.merge_insert("id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.use_lsm_write(True)
|
||||
.execute(_reader([4]))
|
||||
)
|
||||
|
||||
|
||||
def test_lsm_merge_insert_rejects_on_not_primary_key(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
with pytest.raises(Exception, match="primary key"):
|
||||
(
|
||||
table.merge_insert("value")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([1]))
|
||||
)
|
||||
|
||||
|
||||
def test_lsm_merge_insert_rejects_non_upsert(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
# Insert-only (no when_matched_update_all) is not the upsert shape.
|
||||
with pytest.raises(Exception, match="upsert"):
|
||||
table.merge_insert([]).when_not_matched_insert_all().execute(_reader([4]))
|
||||
|
||||
|
||||
def test_lsm_close_writers(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
(
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([7, 8]))
|
||||
)
|
||||
table.close_lsm_writers()
|
||||
# The writer reopens lazily on the next merge_insert.
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([9]))
|
||||
)
|
||||
assert result.num_rows == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_lsm_merge_insert(tmp_path):
|
||||
db = await lancedb.connect_async(
|
||||
tmp_path, read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
table = await db.create_table("t", _reader([1, 2, 3]))
|
||||
await table.set_unenforced_primary_key("id")
|
||||
await table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1))
|
||||
|
||||
builder = (
|
||||
table.merge_insert([]).when_matched_update_all().when_not_matched_insert_all()
|
||||
)
|
||||
result = await builder.execute(_reader([3, 4, 5]))
|
||||
assert result.num_rows == 3
|
||||
await table.close_lsm_writers()
|
||||
@@ -165,22 +165,6 @@ def test_offset(table):
|
||||
assert len(results_with_offset.to_pandas()) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_pandas_kwargs(table, table_async):
|
||||
sync_df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.select(["id"])
|
||||
.limit(1)
|
||||
.to_pandas(split_blocks=True)
|
||||
)
|
||||
assert sync_df["id"].tolist() == [1]
|
||||
|
||||
async_df = await (
|
||||
table_async.query().select(["id"]).limit(2).to_pandas(split_blocks=True)
|
||||
)
|
||||
assert async_df["id"].tolist() == [1, 2]
|
||||
|
||||
|
||||
def test_order_by_plain_query(mem_db):
|
||||
table = mem_db.create_table(
|
||||
"test_order_by",
|
||||
@@ -1512,37 +1496,6 @@ def test_take_queries(tmp_path):
|
||||
]
|
||||
|
||||
|
||||
def test_take_queries_to_batches(tmp_path):
|
||||
# Regression test for the sync take-query path: `to_batches` previously
|
||||
# raised ``AttributeError: 'AsyncTakeQuery' object has no attribute
|
||||
# 'execute'`` because the inherited ``BaseQueryBuilder.to_batches`` called
|
||||
# ``execute`` on the async wrapper instead of the native query.
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table({"idx": list(range(100)), "label": [str(i) for i in range(100)]})
|
||||
table = db.create_table("test", data)
|
||||
|
||||
# Take by offset → to_batches
|
||||
rs = list(table.take_offsets([5, 2, 17]).to_batches())
|
||||
assert all(isinstance(b, pa.RecordBatch) for b in rs)
|
||||
assert sum(b.num_rows for b in rs) == 3
|
||||
assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17]
|
||||
|
||||
# Take by row id → to_batches
|
||||
rs = list(table.take_row_ids([5, 2, 17]).to_batches())
|
||||
assert all(isinstance(b, pa.RecordBatch) for b in rs)
|
||||
assert sum(b.num_rows for b in rs) == 3
|
||||
assert sorted(v for b in rs for v in b.column("idx").to_pylist()) == [2, 5, 17]
|
||||
|
||||
# Take with select projection → to_batches preserves the projection
|
||||
rs = list(table.take_row_ids([5, 2, 17]).select(["label"]).to_batches())
|
||||
assert all(b.schema.names == ["label"] for b in rs)
|
||||
assert sorted(v for b in rs for v in b.column("label").to_pylist()) == [
|
||||
"17",
|
||||
"2",
|
||||
"5",
|
||||
]
|
||||
|
||||
|
||||
def test_getitems(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table(
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -172,155 +171,6 @@ def test_table_len_sync():
|
||||
assert len(table) == 1
|
||||
|
||||
|
||||
def test_remote_connection_serializes():
|
||||
def handler(request):
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
serialized = json.loads(db.serialize())
|
||||
assert isinstance(serialized["client_config"], dict)
|
||||
restored = lancedb.deserialize_conn(db.serialize())
|
||||
assert restored.table_names() == []
|
||||
|
||||
|
||||
def test_remote_table_is_picklable():
|
||||
def handler(request):
|
||||
request.close_connection = True
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
payload = json.dumps(
|
||||
{
|
||||
"version": 1,
|
||||
"schema": {
|
||||
"fields": [
|
||||
{"name": "id", "type": {"type": "int64"}, "nullable": False}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
request.wfile.write(payload.encode())
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"3")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.open_table("test")
|
||||
restored = pickle.loads(pickle.dumps(table))
|
||||
assert restored.count_rows() == 3
|
||||
|
||||
|
||||
def test_remote_table_open_does_not_require_picklable_client_config():
|
||||
from lancedb.remote import HeaderProvider
|
||||
|
||||
class LocalHeaderProvider(HeaderProvider):
|
||||
def get_headers(self):
|
||||
return {"X-Test-Header": "present"}
|
||||
|
||||
def handler(request):
|
||||
request.close_connection = True
|
||||
assert request.headers.get("X-Test-Header") == "present"
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"version": 1, "schema": {"fields": []}}')
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"3")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 0), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
port = server.server_address[1]
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
try:
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
"header_provider": LocalHeaderProvider(),
|
||||
},
|
||||
)
|
||||
table = db.open_table("test")
|
||||
assert table.count_rows() == 3
|
||||
with pytest.raises(ValueError, match="header_provider"):
|
||||
pickle.dumps(table)
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
def test_remote_permutation_is_picklable():
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
rows = list(range(10))
|
||||
|
||||
def handler(request):
|
||||
request.close_connection = True
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
payload = json.dumps(
|
||||
{
|
||||
"version": 1,
|
||||
"schema": {
|
||||
"fields": [
|
||||
{"name": "a", "type": {"type": "int64"}, "nullable": False}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
request.wfile.write(payload.encode())
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(str(len(rows)).encode())
|
||||
elif request.path == "/v1/table/test/query/":
|
||||
content_len = int(request.headers.get("Content-Length"))
|
||||
body = json.loads(request.rfile.read(content_len))
|
||||
if "filter" in body:
|
||||
match = re.search(r"_rowoffset in \((.*?)\)", body["filter"])
|
||||
offsets = [int(offset.strip()) for offset in match.group(1).split(",")]
|
||||
else:
|
||||
offsets = rows
|
||||
table = pa.table({"a": [rows[offset] for offset in offsets]})
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
|
||||
request.end_headers()
|
||||
with pa.ipc.new_file(request.wfile, schema=table.schema) as writer:
|
||||
writer.write_table(table)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
permutation = Permutation.identity(db.open_table("test"))
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
assert restored.__getitems__([0, 2, 4]) == [{"a": 0}, {"a": 2}, {"a": 4}]
|
||||
|
||||
|
||||
def test_create_table_exist_ok():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/create/?mode=exist_ok":
|
||||
@@ -419,25 +269,6 @@ def test_table_unimplemented_functions():
|
||||
table.to_pandas()
|
||||
|
||||
|
||||
def test_table_to_pandas_not_supported():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/create/?mode=create":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.create_table("test", [{"id": 1}])
|
||||
with pytest.raises(NotImplementedError):
|
||||
table.to_pandas()
|
||||
with pytest.raises(NotImplementedError):
|
||||
table.to_pandas(blob_mode="bytes", split_blocks=True)
|
||||
|
||||
|
||||
def test_table_add_in_threadpool():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/insert/":
|
||||
@@ -512,22 +343,6 @@ def test_table_create_indices():
|
||||
schema=dict(
|
||||
fields=[
|
||||
dict(name="id", type={"type": "int64"}, nullable=False),
|
||||
dict(name="text", type={"type": "string"}, nullable=False),
|
||||
dict(
|
||||
name="vector",
|
||||
type={
|
||||
"type": "fixed_size_list",
|
||||
"fields": [
|
||||
dict(
|
||||
name="item",
|
||||
type={"type": "float"},
|
||||
nullable=True,
|
||||
)
|
||||
],
|
||||
"length": 2,
|
||||
},
|
||||
nullable=False,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -586,25 +401,22 @@ def test_table_create_indices():
|
||||
# This is a smoke-test.
|
||||
table = db.create_table("test", [{"id": 1}])
|
||||
|
||||
# Test create_scalar_index with custom name (legacy method)
|
||||
with pytest.warns(DeprecationWarning, match="create_scalar_index"):
|
||||
table.create_scalar_index(
|
||||
"id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
|
||||
)
|
||||
# Test create_scalar_index with custom name
|
||||
table.create_scalar_index(
|
||||
"id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
|
||||
)
|
||||
|
||||
# Test create_fts_index with custom name (legacy method)
|
||||
with pytest.warns(DeprecationWarning, match="create_fts_index"):
|
||||
table.create_fts_index(
|
||||
"text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
|
||||
)
|
||||
# Test create_fts_index with custom name
|
||||
table.create_fts_index(
|
||||
"text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
|
||||
)
|
||||
|
||||
# Test create_index with custom name (legacy form: vector_column_name kwarg)
|
||||
with pytest.warns(DeprecationWarning, match="create_index"):
|
||||
table.create_index(
|
||||
vector_column_name="vector",
|
||||
wait_timeout=timedelta(seconds=10),
|
||||
name="custom_vector_idx",
|
||||
)
|
||||
# Test create_index with custom name
|
||||
table.create_index(
|
||||
vector_column_name="vector",
|
||||
wait_timeout=timedelta(seconds=10),
|
||||
name="custom_vector_idx",
|
||||
)
|
||||
|
||||
# Validate that the name parameter was passed correctly in requests
|
||||
assert len(received_requests) == 3
|
||||
@@ -633,98 +445,6 @@ def test_table_create_indices():
|
||||
table.drop_index("custom_fts_idx")
|
||||
|
||||
|
||||
def test_remote_create_index_new_api():
|
||||
received_requests = []
|
||||
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/create_index/":
|
||||
content_len = int(request.headers.get("Content-Length", 0))
|
||||
body = request.rfile.read(content_len) if content_len > 0 else b""
|
||||
received_requests.append(json.loads(body) if body else {})
|
||||
request.send_response(200)
|
||||
request.end_headers()
|
||||
elif request.path == "/v1/table/test/create/?mode=create":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
elif request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(
|
||||
json.dumps(
|
||||
dict(
|
||||
version=1,
|
||||
schema=dict(
|
||||
fields=[
|
||||
dict(name="id", type={"type": "int64"}, nullable=False),
|
||||
dict(
|
||||
name="category",
|
||||
type={"type": "string"},
|
||||
nullable=False,
|
||||
),
|
||||
dict(
|
||||
name="text", type={"type": "string"}, nullable=False
|
||||
),
|
||||
dict(
|
||||
name="vector",
|
||||
type={
|
||||
"type": "fixed_size_list",
|
||||
"fields": [
|
||||
dict(
|
||||
name="item",
|
||||
type={"type": "float"},
|
||||
nullable=True,
|
||||
)
|
||||
],
|
||||
"length": 2,
|
||||
},
|
||||
nullable=False,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
).encode()
|
||||
)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
from lancedb.index import BTree, FTS, IvfPq, IvfRq
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.create_table("test", [{"id": 1}])
|
||||
|
||||
# New API: column-first, config= kwarg. Should NOT emit DeprecationWarning.
|
||||
import warnings as _warnings
|
||||
|
||||
with _warnings.catch_warnings():
|
||||
_warnings.simplefilter("error", DeprecationWarning)
|
||||
table.create_index("vector", config=IvfPq(distance_type="l2"))
|
||||
table.create_index("category", config=BTree())
|
||||
table.create_index("text", config=FTS())
|
||||
# IvfRq via new API
|
||||
table.create_index("vector", config=IvfRq(distance_type="l2"))
|
||||
|
||||
# Legacy index_type="IVF_RQ" routes to IvfRq config under the hood.
|
||||
with pytest.warns(DeprecationWarning, match="create_index"):
|
||||
table.create_index(
|
||||
vector_column_name="vector",
|
||||
index_type="IVF_RQ",
|
||||
num_partitions=8,
|
||||
)
|
||||
|
||||
assert len(received_requests) == 5
|
||||
assert [req["column"] for req in received_requests] == [
|
||||
"vector",
|
||||
"category",
|
||||
"text",
|
||||
"vector",
|
||||
"vector",
|
||||
]
|
||||
|
||||
|
||||
def test_table_wait_for_index_timeout():
|
||||
def handler(request):
|
||||
index_stats = dict(
|
||||
@@ -1550,10 +1270,6 @@ def _remote_fork_child(port: int, queue) -> None:
|
||||
queue.put(db.table_names())
|
||||
|
||||
|
||||
def _remote_table_fork_child(table, queue) -> None:
|
||||
queue.put(table.count_rows())
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
@@ -1616,65 +1332,3 @@ def test_remote_connection_after_fork():
|
||||
finally:
|
||||
server.shutdown()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_inherited_remote_table_reopens_after_fork():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"version": 1, "schema": {"fields": []}}')
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"7")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler))
|
||||
port = server.server_address[1]
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.start()
|
||||
try:
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
table = db.open_table("test")
|
||||
assert table.count_rows() == 7
|
||||
|
||||
ctx = mp.get_context("fork")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(target=_remote_table_fork_child, args=(table, queue))
|
||||
proc.start()
|
||||
proc.join(timeout=15)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail("Remote table hung after fork")
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no result"
|
||||
assert queue.get() == 7
|
||||
finally:
|
||||
server.shutdown()
|
||||
server_thread.join()
|
||||
|
||||
@@ -603,89 +603,3 @@ def test_cross_encoder_reranker_return_all(tmp_path):
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert "_score" in result.column_names
|
||||
assert "_distance" in result.column_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression tests for LinearCombinationReranker scoring bugs (issue #3154)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_linear_combination_best_match_ranks_first():
|
||||
"""
|
||||
The document that is BOTH the closest vector match AND the only FTS match
|
||||
must rank first. Previously _combine_score subtracted from 1, inverting
|
||||
the ranking so the worst document ranked highest.
|
||||
"""
|
||||
reranker = LinearCombinationReranker(weight=0.7, return_score="all")
|
||||
|
||||
# rowid 0: perfect vector match, sole FTS match → should rank 1st
|
||||
# rowid 1: mediocre vector, no FTS match
|
||||
# rowid 2: bad vector, no FTS match
|
||||
vector_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0, 1, 2],
|
||||
"_distance": [0.0, 0.5, 0.9],
|
||||
}
|
||||
)
|
||||
fts_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0],
|
||||
"_score": [1.0],
|
||||
}
|
||||
)
|
||||
|
||||
combined = reranker.merge_results(vector_results, fts_results, fill=1.0)
|
||||
scores = dict(
|
||||
zip(
|
||||
combined["_rowid"].to_pylist(),
|
||||
combined["_relevance_score"].to_pylist(),
|
||||
)
|
||||
)
|
||||
|
||||
# rowid 0 must have the highest relevance score
|
||||
assert scores[0] > scores[1], (
|
||||
f"Best match (rowid 0, score={scores[0]:.4f}) should beat "
|
||||
f"mid match (rowid 1, score={scores[1]:.4f})"
|
||||
)
|
||||
assert scores[1] > scores[2], (
|
||||
f"Mid match (rowid 1, score={scores[1]:.4f}) should beat "
|
||||
f"bad match (rowid 2, score={scores[2]:.4f})"
|
||||
)
|
||||
|
||||
|
||||
def test_linear_combination_missing_fts_is_penalised():
|
||||
"""
|
||||
A document with no FTS match must score *lower* than a document that
|
||||
has a mediocre FTS match, everything else being equal. Previously
|
||||
missing-FTS entries used fill=1.0 directly, which gave them a reward
|
||||
(via the 1-(...) inversion) instead of a penalty.
|
||||
"""
|
||||
reranker = LinearCombinationReranker(weight=0.5, return_score="all")
|
||||
|
||||
vector_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0, 1],
|
||||
"_distance": [0.2, 0.2], # identical vector scores
|
||||
}
|
||||
)
|
||||
fts_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0], # rowid 1 has no FTS match
|
||||
"_score": [0.3], # small FTS score
|
||||
}
|
||||
)
|
||||
|
||||
combined = reranker.merge_results(vector_results, fts_results, fill=1.0)
|
||||
scores = dict(
|
||||
zip(
|
||||
combined["_rowid"].to_pylist(),
|
||||
combined["_relevance_score"].to_pylist(),
|
||||
)
|
||||
)
|
||||
|
||||
# rowid 0 has a small FTS score; rowid 1 has none.
|
||||
# Even a small FTS contribution should beat having none at all.
|
||||
assert scores[0] > scores[1], (
|
||||
f"Document with FTS score (rowid 0, {scores[0]:.4f}) should beat "
|
||||
f"document with no FTS match (rowid 1, {scores[1]:.4f})"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from datetime import date, datetime, timedelta
|
||||
from time import sleep
|
||||
from typing import List
|
||||
@@ -12,7 +11,7 @@ from unittest.mock import patch
|
||||
|
||||
import lancedb
|
||||
from lancedb.dependencies import _PANDAS_AVAILABLE
|
||||
from lancedb.index import BTree, FTS, HnswFlat, HnswPq, HnswSq, IvfPq
|
||||
from lancedb.index import HnswFlat, HnswPq, HnswSq, IvfPq
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
@@ -34,7 +33,7 @@ def test_basic(mem_db: DBConnection):
|
||||
table = mem_db.create_table("test", data=data)
|
||||
|
||||
assert table.name == "test"
|
||||
assert "LanceTable(name='test', _conn=LanceDBConnection(" in repr(table)
|
||||
assert "LanceTable(name='test', version=1, _conn=LanceDBConnection(" in repr(table)
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
@@ -48,85 +47,6 @@ def test_basic(mem_db: DBConnection):
|
||||
assert table.to_arrow() == expected_data
|
||||
|
||||
|
||||
def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
|
||||
pd = pytest.importorskip("pandas")
|
||||
data = pa.table({"id": [1, 2], "text": ["one", "two"]})
|
||||
table = tmp_db.create_table("test_to_pandas_old_call", data=data)
|
||||
|
||||
expected = data.to_pandas()
|
||||
pd.testing.assert_frame_equal(table.to_pandas(), expected)
|
||||
|
||||
|
||||
def test_table_to_pandas_blob_bytes(tmp_db: DBConnection):
|
||||
pytest.importorskip("lance")
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
table = tmp_db.create_table("test_to_pandas_blob_bytes", data=data)
|
||||
|
||||
df = table.to_pandas(blob_mode="bytes")
|
||||
|
||||
assert df["blob"].tolist() == [b"hello", b"world"]
|
||||
|
||||
|
||||
def test_table_to_pandas_kwargs(tmp_db: DBConnection):
|
||||
pd = pytest.importorskip("pandas")
|
||||
data = pa.table({"id": pa.array([1, 2], pa.int64())})
|
||||
table = tmp_db.create_table("test_to_pandas_kwargs", data=data)
|
||||
|
||||
df = table.to_pandas(types_mapper=pd.ArrowDtype)
|
||||
|
||||
assert str(df["id"].dtype) == "int64[pyarrow]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
|
||||
pytest.importorskip("lance")
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
table = await tmp_db_async.create_table(
|
||||
"test_async_to_pandas_blob_bytes", data=data
|
||||
)
|
||||
|
||||
df = await table.to_pandas(blob_mode="bytes")
|
||||
|
||||
assert df["blob"].tolist() == [b"hello", b"world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_table_to_pandas_kwargs(tmp_db_async: AsyncConnection):
|
||||
pd = pytest.importorskip("pandas")
|
||||
data = pa.table({"id": pa.array([1, 2], pa.int64())})
|
||||
table = await tmp_db_async.create_table("test_async_to_pandas_kwargs", data=data)
|
||||
|
||||
df = await table.to_pandas(types_mapper=pd.ArrowDtype)
|
||||
|
||||
assert str(df["id"].dtype) == "int64[pyarrow]"
|
||||
|
||||
|
||||
def test_create_table_infers_large_int_vectors(mem_db: DBConnection):
|
||||
data = [{"vector": [0, 300]}]
|
||||
|
||||
@@ -903,79 +823,6 @@ async def test_async_tags(mem_db_async: AsyncConnection):
|
||||
)
|
||||
|
||||
|
||||
def test_branches(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = db.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
assert table.count_rows() == 2
|
||||
|
||||
# fork an isolated, writable branch from main
|
||||
branch = table.branches.create("exp")
|
||||
assert branch.count_rows() == 2
|
||||
branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])
|
||||
|
||||
# writes on the branch do not touch main
|
||||
assert branch.count_rows() == 3
|
||||
assert table.count_rows() == 2
|
||||
|
||||
# the branch is listed, with main (None) as its parent
|
||||
branches = table.branches.list()
|
||||
assert "exp" in branches
|
||||
assert branches["exp"]["parent_branch"] is None
|
||||
|
||||
# from_ref="main" is equivalent to the default
|
||||
table.branches.create("exp2", from_ref="main")
|
||||
assert table.branches.list()["exp2"]["parent_branch"] is None
|
||||
|
||||
# checkout returns a handle scoped to the branch's latest
|
||||
checked_out = table.branches.checkout("exp")
|
||||
assert checked_out.count_rows() == 3
|
||||
|
||||
# delete removes it
|
||||
table.branches.delete("exp")
|
||||
table.branches.delete("exp2")
|
||||
assert "exp" not in table.branches.list()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_branches(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
table = await db.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
assert await table.count_rows() == 2
|
||||
|
||||
branch = await table.branches.create("exp")
|
||||
assert await branch.count_rows() == 2
|
||||
await branch.add(data=[{"vector": [10.0, 11.0], "item": "baz", "price": 30.0}])
|
||||
|
||||
assert await branch.count_rows() == 3
|
||||
assert await table.count_rows() == 2
|
||||
|
||||
branches = await table.branches.list()
|
||||
assert "exp" in branches
|
||||
assert branches["exp"]["parent_branch"] is None
|
||||
|
||||
await table.branches.create("exp2", from_ref="main")
|
||||
assert (await table.branches.list())["exp2"]["parent_branch"] is None
|
||||
|
||||
checked_out = await table.branches.checkout("exp")
|
||||
assert await checked_out.count_rows() == 3
|
||||
|
||||
await table.branches.delete("exp")
|
||||
await table.branches.delete("exp2")
|
||||
assert "exp" not in await table.branches.list()
|
||||
|
||||
|
||||
@patch("lancedb.table.AsyncTable.create_index")
|
||||
def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
@@ -1002,12 +849,7 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
num_bits=4,
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
"vector", replace=True, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
# Test with target_partition_size
|
||||
@@ -1027,12 +869,7 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
target_partition_size=8192,
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
"vector", replace=True, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
# target_partition_size has a default value,
|
||||
@@ -1051,12 +888,7 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
num_bits=4,
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
"vector", replace=True, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
@@ -1067,12 +899,7 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
)
|
||||
expected_config = HnswPq(distance_type="dot")
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector",
|
||||
replace=False,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
"my_vector", replace=False, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
@@ -1087,12 +914,7 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
"my_vector", replace=True, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
@@ -1107,12 +929,7 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
"my_vector", replace=True, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
|
||||
@@ -1136,7 +953,6 @@ def test_create_index_name_and_train_parameters(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name="my_custom_index",
|
||||
train=True,
|
||||
)
|
||||
@@ -1144,82 +960,13 @@ def test_create_index_name_and_train_parameters(
|
||||
# Test with train=False
|
||||
table.create_index(vector_column_name="vector", train=False)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=False,
|
||||
"vector", replace=True, config=expected_config, name=None, train=False
|
||||
)
|
||||
|
||||
# Test with both name and train
|
||||
table.create_index(vector_column_name="vector", name="my_index_name", train=True)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name="my_index_name",
|
||||
train=True,
|
||||
)
|
||||
|
||||
|
||||
@patch("lancedb.table.AsyncTable.create_index")
|
||||
def test_create_index_legacy_emits_deprecation_warning(
|
||||
mock_create_index, mem_db: DBConnection
|
||||
):
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[{"vector": [3.1, 4.1]}, {"vector": [5.9, 26.5]}],
|
||||
)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="create_index"):
|
||||
table.create_index(metric="l2", num_partitions=8, vector_column_name="vector")
|
||||
|
||||
|
||||
@patch("lancedb.table.AsyncTable.create_index")
|
||||
def test_create_index_new_api(mock_create_index, mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "category": "a", "text": "hello world"},
|
||||
{"vector": [5.9, 26.5], "category": "b", "text": "goodbye"},
|
||||
],
|
||||
)
|
||||
|
||||
# Vector index via new API should not warn
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", DeprecationWarning)
|
||||
table.create_index("vector", config=IvfPq(distance_type="l2"))
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=IvfPq(distance_type="l2"),
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
# Scalar index via new API
|
||||
table.create_index("category", config=BTree())
|
||||
mock_create_index.assert_called_with(
|
||||
"category",
|
||||
replace=True,
|
||||
config=BTree(),
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
# FTS index via new API
|
||||
table.create_index("text", config=FTS(with_position=True))
|
||||
mock_create_index.assert_called_with(
|
||||
"text",
|
||||
replace=True,
|
||||
config=FTS(with_position=True),
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
"vector", replace=True, config=expected_config, name="my_index_name", train=True
|
||||
)
|
||||
|
||||
|
||||
@@ -2035,9 +1782,8 @@ def test_create_scalar_index(mem_db: DBConnection):
|
||||
"my_table",
|
||||
data=test_data,
|
||||
)
|
||||
# Test with default name; confirm DeprecationWarning fires
|
||||
with pytest.warns(DeprecationWarning, match="create_scalar_index"):
|
||||
table.create_scalar_index("x")
|
||||
# Test with default name
|
||||
table.create_scalar_index("x")
|
||||
indices = table.list_indices()
|
||||
assert len(indices) == 1
|
||||
scalar_index = indices[0]
|
||||
@@ -2065,59 +1811,6 @@ def test_create_scalar_index(mem_db: DBConnection):
|
||||
assert scalar_index.name == "custom_y_index"
|
||||
|
||||
|
||||
def test_create_index_nested_field_paths(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("metadata", pa.struct([pa.field("user_id", pa.int32())])),
|
||||
pa.field(
|
||||
"image",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[
|
||||
{
|
||||
"metadata": {"user_id": i},
|
||||
"image": {"embedding": [float(i), float(i + 1)]},
|
||||
}
|
||||
for i in range(256)
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_index_paths", data=data)
|
||||
|
||||
table.create_scalar_index("metadata.user_id", name="metadata_user_id_idx")
|
||||
table.create_index(
|
||||
vector_column_name="image.embedding",
|
||||
num_partitions=1,
|
||||
num_sub_vectors=1,
|
||||
name="image_embedding_idx",
|
||||
)
|
||||
|
||||
indices = sorted(table.list_indices(), key=lambda idx: idx.name)
|
||||
assert [(idx.name, idx.index_type, idx.columns) for idx in indices] == [
|
||||
("image_embedding_idx", "IvfPq", ["image.embedding"]),
|
||||
("metadata_user_id_idx", "BTree", ["metadata.user_id"]),
|
||||
]
|
||||
|
||||
vector_results = (
|
||||
table.search([0.0, 1.0], vector_column_name="image.embedding")
|
||||
.limit(1)
|
||||
.to_list()
|
||||
)
|
||||
assert len(vector_results) == 1
|
||||
assert vector_results[0]["metadata"]["user_id"] == 0
|
||||
|
||||
default_vector_results = table.search([0.0, 1.0]).limit(1).to_list()
|
||||
assert len(default_vector_results) == 1
|
||||
assert default_vector_results[0]["metadata"]["user_id"] == 0
|
||||
|
||||
filtered_results = table.search().where("metadata.user_id = 42").limit(1).to_list()
|
||||
assert len(filtered_results) == 1
|
||||
assert filtered_results[0]["metadata"]["user_id"] == 42
|
||||
|
||||
|
||||
def test_empty_query(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"my_table",
|
||||
@@ -2192,74 +1885,6 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
|
||||
table.search(q).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_search_infers_single_nested_vector(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field(
|
||||
"image",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[
|
||||
{"id": 0, "image": {"embedding": [0.0, 1.0]}},
|
||||
{"id": 1, "image": {"embedding": [10.0, 11.0]}},
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_default_search", data=data)
|
||||
|
||||
result = table.search([0.0, 1.0]).limit(1).to_list()
|
||||
assert result[0]["id"] == 0
|
||||
|
||||
|
||||
def test_search_nested_vector_multiple_candidates(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field(
|
||||
"image",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
pa.field(
|
||||
"text",
|
||||
pa.struct([pa.field("embedding", pa.list_(pa.float32(), 2))]),
|
||||
),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[
|
||||
{
|
||||
"image": {"embedding": [0.0, 1.0]},
|
||||
"text": {"embedding": [2.0, 3.0]},
|
||||
}
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_multiple_candidates", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="image.embedding.*text.embedding"):
|
||||
table.search([0.0, 1.0]).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_search_nested_vector_no_candidates(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field("metadata", pa.struct([pa.field("label", pa.string())])),
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_pylist(
|
||||
[{"id": 0, "metadata": {"label": "cat"}}],
|
||||
schema=schema,
|
||||
)
|
||||
table = mem_db.create_table("nested_vector_no_candidates", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="no vector column"):
|
||||
table.search([0.0, 1.0]).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_compact_cleanup(tmp_db: DBConnection):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
@@ -2545,30 +2170,6 @@ def test_alter_columns(mem_db: DBConnection):
|
||||
assert table.to_arrow().column_names == ["new_id"]
|
||||
|
||||
|
||||
def test_update_field_metadata(mem_db: DBConnection):
|
||||
data = pa.table({"id": [0, 1], "category": ["a", "b"]})
|
||||
table = mem_db.create_table("my_table", data=data)
|
||||
|
||||
res = table.update_field_metadata(
|
||||
{"path": "category", "metadata": {"unit": "label", "pii": "false"}}
|
||||
)
|
||||
assert res.version == 2
|
||||
# Arrow field metadata is bytes-keyed
|
||||
assert table.schema.field("category").metadata == {
|
||||
b"unit": b"label",
|
||||
b"pii": b"false",
|
||||
}
|
||||
|
||||
# merge: add a key, delete one via None, keep the rest
|
||||
table.update_field_metadata(
|
||||
{"path": "category", "metadata": {"source": "import", "pii": None}}
|
||||
)
|
||||
assert table.schema.field("category").metadata == {
|
||||
b"unit": b"label",
|
||||
b"source": b"import",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alter_columns_async(mem_db_async: AsyncConnection):
|
||||
data = pa.table({"id": [0, 1]})
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import http.server
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
@@ -20,107 +15,6 @@ from lancedb.util import tbl_to_tensor
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
REMOTE_ROWS = list(range(100))
|
||||
|
||||
|
||||
def _make_mock_http_handler(handler):
|
||||
class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
handler(self)
|
||||
|
||||
def do_POST(self):
|
||||
handler(self)
|
||||
|
||||
return MockLanceDBHandler
|
||||
|
||||
|
||||
def _remote_schema_payload():
|
||||
return {
|
||||
"version": 1,
|
||||
"schema": {
|
||||
"fields": [
|
||||
{"name": "a", "type": {"type": "int64"}, "nullable": False},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _offsets_from_filter(filter_sql: str | None) -> list[int]:
|
||||
if filter_sql is None:
|
||||
return REMOTE_ROWS
|
||||
match = re.search(r"_rowoffset in \((.*?)\)", filter_sql)
|
||||
if match is None:
|
||||
return REMOTE_ROWS
|
||||
raw_offsets = match.group(1).strip()
|
||||
if raw_offsets == "":
|
||||
return []
|
||||
return [int(offset.strip()) for offset in raw_offsets.split(",")]
|
||||
|
||||
|
||||
def _remote_dataset_handler(request):
|
||||
request.close_connection = True
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(json.dumps(_remote_schema_payload()).encode())
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(str(len(REMOTE_ROWS)).encode())
|
||||
elif request.path == "/v1/table/test/query/":
|
||||
content_len = int(request.headers.get("Content-Length"))
|
||||
body = json.loads(request.rfile.read(content_len))
|
||||
offsets = _offsets_from_filter(body.get("filter"))
|
||||
requested_columns = body.get("columns") or ["a"]
|
||||
if isinstance(requested_columns, dict):
|
||||
requested_columns = list(requested_columns)
|
||||
|
||||
data = {}
|
||||
for column in requested_columns:
|
||||
if column == "a":
|
||||
data[column] = [REMOTE_ROWS[offset] for offset in offsets]
|
||||
elif column == "_rowoffset":
|
||||
data[column] = offsets
|
||||
elif column == "_rowid":
|
||||
data[column] = offsets
|
||||
|
||||
table = pa.table(data)
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
|
||||
request.end_headers()
|
||||
with pa.ipc.new_file(request.wfile, schema=table.schema) as writer:
|
||||
writer.write_table(table)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _remote_dataset_table():
|
||||
with http.server.ThreadingHTTPServer(
|
||||
("localhost", 0), _make_mock_http_handler(_remote_dataset_handler)
|
||||
) as server:
|
||||
port = server.server_address[1]
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
try:
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
yield db.open_table("test")
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
def _open_native_table(uri: str, table_name: str):
|
||||
"""Top-level connection factory used by the explicit-factory pickle test.
|
||||
|
||||
@@ -213,39 +107,6 @@ def test_permutation_dataloader_multiprocessing(tmp_db):
|
||||
assert seen == 1000
|
||||
|
||||
|
||||
def test_remote_table_dataloader_multiprocessing():
|
||||
with _remote_dataset_table() as table:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
table,
|
||||
collate_fn=tbl_to_tensor,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="spawn",
|
||||
)
|
||||
seen = 0
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 1
|
||||
assert batch.size(1) == 10
|
||||
seen += batch.size(1)
|
||||
assert seen == len(REMOTE_ROWS)
|
||||
|
||||
|
||||
def test_remote_permutation_dataloader_multiprocessing():
|
||||
with _remote_dataset_table() as table:
|
||||
permutation = Permutation.identity(table)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="spawn",
|
||||
)
|
||||
seen = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
seen += batch["a"].size(0)
|
||||
assert seen == len(REMOTE_ROWS)
|
||||
|
||||
|
||||
def test_permutation_pickle_with_connection_factory(tmp_path):
|
||||
"""When the user provides a connection_factory, pickling should round-trip
|
||||
through that factory rather than introspecting the connection URI. Useful
|
||||
@@ -310,35 +171,6 @@ def _multiworker_dataloader_target(db_uri: str, result_queue):
|
||||
result_queue.put(count)
|
||||
|
||||
|
||||
def _remote_multiworker_dataloader_target(port: int, result_queue):
|
||||
import lancedb
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
table = db.open_table("test")
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="fork",
|
||||
)
|
||||
count = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
count += 1
|
||||
result_queue.put(count)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
@@ -376,46 +208,3 @@ def test_permutation_dataloader_fork_workers(tmp_path):
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no batches"
|
||||
assert queue.get() == 100
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_remote_permutation_dataloader_fork_workers():
|
||||
with http.server.ThreadingHTTPServer(
|
||||
("localhost", 0), _make_mock_http_handler(_remote_dataset_handler)
|
||||
) as server:
|
||||
port = server.server_address[1]
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
try:
|
||||
ctx = mp.get_context("spawn")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(
|
||||
target=_remote_multiworker_dataloader_target,
|
||||
args=(port, queue),
|
||||
)
|
||||
proc.start()
|
||||
proc.join(timeout=30)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail(
|
||||
"Remote permutation hung when iterated in a fork-based "
|
||||
"DataLoader worker"
|
||||
)
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no batches"
|
||||
assert queue.get() == 10
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
@@ -539,7 +539,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn connect(
|
||||
py: Python<'_>,
|
||||
@@ -553,6 +553,7 @@ pub fn connect(
|
||||
session: Option<crate::session::Session>,
|
||||
manifest_enabled: bool,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
oauth_config: Option<crate::oauth::PyOAuthConfig>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
@@ -582,6 +583,10 @@ pub fn connect(
|
||||
if let Some(client_config) = client_config {
|
||||
builder = builder.client_config(client_config.into());
|
||||
}
|
||||
if let Some(oauth_config) = oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig = oauth_config.into();
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
if let Some(session) = session {
|
||||
builder = builder.session(session.inner.clone());
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use query::{FTSQuery, HybridQuery, Query, VectorQuery};
|
||||
use session::Session;
|
||||
use table::{
|
||||
AddColumnsResult, AddResult, AlterColumnsResult, DeleteResult, DropColumnsResult, LsmWriteSpec,
|
||||
MergeResult, Table, UpdateFieldMetadataResult, UpdateResult,
|
||||
MergeResult, Table, UpdateResult,
|
||||
};
|
||||
|
||||
pub mod arrow;
|
||||
@@ -26,6 +26,7 @@ pub mod expr;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod namespace;
|
||||
pub mod oauth;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod runtime;
|
||||
@@ -50,7 +51,6 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<RecordBatchStream>()?;
|
||||
m.add_class::<AddColumnsResult>()?;
|
||||
m.add_class::<AlterColumnsResult>()?;
|
||||
m.add_class::<UpdateFieldMetadataResult>()?;
|
||||
m.add_class::<AddResult>()?;
|
||||
m.add_class::<MergeResult>()?;
|
||||
m.add_class::<LsmWriteSpec>()?;
|
||||
|
||||
53
python/src/oauth.rs
Normal file
53
python/src/oauth.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use pyo3::FromPyObject;
|
||||
|
||||
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
|
||||
|
||||
/// Python-side OAuth configuration, extracted via FromPyObject.
|
||||
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyOAuthConfig {
|
||||
pub issuer_url: String,
|
||||
pub client_id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub flow: String,
|
||||
pub client_secret: Option<String>,
|
||||
pub redirect_uri: Option<String>,
|
||||
pub callback_port: Option<u16>,
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
pub token_file: Option<String>,
|
||||
pub refresh_buffer_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl From<PyOAuthConfig> for OAuthConfig {
|
||||
fn from(py: PyOAuthConfig) -> Self {
|
||||
let flow = match py.flow.as_str() {
|
||||
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||
"authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE {
|
||||
redirect_uri: py.redirect_uri,
|
||||
callback_port: py.callback_port,
|
||||
},
|
||||
"device_code" => OAuthFlow::DeviceCode,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: py.managed_identity_client_id,
|
||||
},
|
||||
"workload_identity" => OAuthFlow::WorkloadIdentity {
|
||||
token_file: py
|
||||
.token_file
|
||||
.expect("token_file is required for workload_identity flow"),
|
||||
},
|
||||
other => panic!("Unknown OAuth flow type: {other}"),
|
||||
};
|
||||
|
||||
OAuthConfig {
|
||||
issuer_url: py.issuer_url,
|
||||
client_id: py.client_id,
|
||||
client_secret: py.client_secret,
|
||||
scopes: py.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: py.refresh_buffer_secs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -16,12 +16,12 @@ use arrow::{
|
||||
pyarrow::{FromPyArrow, PyArrowType, ToPyArrow},
|
||||
};
|
||||
use lancedb::table::{
|
||||
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform,
|
||||
OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
|
||||
AddDataMode, ColumnAlteration, Duration, NewColumnTransform, OptimizeAction, OptimizeOptions,
|
||||
Table as LanceDbTable,
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods,
|
||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
||||
};
|
||||
@@ -143,20 +143,18 @@ pub struct MergeResult {
|
||||
pub num_inserted_rows: u64,
|
||||
pub num_deleted_rows: u64,
|
||||
pub num_attempts: u32,
|
||||
pub num_rows: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl MergeResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!(
|
||||
"MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={}, num_rows={})",
|
||||
"MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={})",
|
||||
self.version,
|
||||
self.num_updated_rows,
|
||||
self.num_inserted_rows,
|
||||
self.num_deleted_rows,
|
||||
self.num_attempts,
|
||||
self.num_rows
|
||||
self.num_attempts
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -169,7 +167,6 @@ impl From<lancedb::table::MergeResult> for MergeResult {
|
||||
num_inserted_rows: result.num_inserted_rows,
|
||||
num_deleted_rows: result.num_deleted_rows,
|
||||
num_attempts: result.num_attempts,
|
||||
num_rows: result.num_rows,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,12 +194,6 @@ impl LsmWriteSpec {
|
||||
}
|
||||
|
||||
/// Identity sharding — shard by the raw value of `column`.
|
||||
///
|
||||
/// `column` must be a deterministic function of the unenforced primary
|
||||
/// key: every row with a given primary key must always produce the same
|
||||
/// `column` value, or upserts of that key can land in different shards
|
||||
/// and a stale version can win. Typically `column` is the primary key
|
||||
/// itself or a stable attribute of it.
|
||||
#[staticmethod]
|
||||
pub fn identity(column: String) -> Self {
|
||||
Self {
|
||||
@@ -357,27 +348,6 @@ impl From<lancedb::table::AlterColumnsResult> for AlterColumnsResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UpdateFieldMetadataResult {
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl UpdateFieldMetadataResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("UpdateFieldMetadataResult(version={})", self.version)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::UpdateFieldMetadataResult> for UpdateFieldMetadataResult {
|
||||
fn from(result: lancedb::table::UpdateFieldMetadataResult) -> Self {
|
||||
Self {
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DropColumnsResult {
|
||||
@@ -864,11 +834,6 @@ impl Table {
|
||||
Ok(Tags::new(self.inner_ref()?.clone()))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
pub fn branches(&self) -> PyResult<Branches> {
|
||||
Ok(Branches::new(self.inner_ref()?.clone()))
|
||||
}
|
||||
|
||||
#[pyo3(signature = (offsets))]
|
||||
pub fn take_offsets(self_: PyRef<'_, Self>, offsets: Vec<u64>) -> PyResult<TakeQuery> {
|
||||
Ok(TakeQuery::new(
|
||||
@@ -968,12 +933,6 @@ impl Table {
|
||||
if let Some(use_index) = parameters.use_index {
|
||||
builder.use_index(use_index);
|
||||
}
|
||||
if let Some(use_lsm_write) = parameters.use_lsm_write {
|
||||
builder.use_lsm_write(use_lsm_write);
|
||||
}
|
||||
if let Some(validate_single_shard) = parameters.validate_single_shard {
|
||||
builder.validate_single_shard(validate_single_shard);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let res = builder.execute(Box::new(batches)).await.infer_error()?;
|
||||
@@ -1012,13 +971,6 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn close_lsm_writers(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.close_lsm_writers().await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
@@ -1128,57 +1080,31 @@ impl Table {
|
||||
field_name: String,
|
||||
metadata: &Bound<'_, PyDict>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
// Deprecated: forwards to the update_field_metadata path (replace mode).
|
||||
let mut update = FieldMetadataUpdate::new(field_name).replace();
|
||||
for (key, value) in metadata.into_iter() {
|
||||
update = update.set(key.extract::<String>()?, value.extract::<String>()?);
|
||||
let mut new_metadata = HashMap::<String, String>::new();
|
||||
for (column_name, value) in metadata.into_iter() {
|
||||
let key: String = column_name.extract()?;
|
||||
let value: String = value.extract()?;
|
||||
new_metadata.insert(key, value);
|
||||
}
|
||||
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.update_field_metadata(&[update]).await.infer_error()?;
|
||||
let native_tbl = inner
|
||||
.as_native()
|
||||
.ok_or_else(|| PyValueError::new_err("This cannot be run on a remote table"))?;
|
||||
let schema = native_tbl.manifest().await.infer_error()?.schema;
|
||||
let field = schema
|
||||
.field(&field_name)
|
||||
.ok_or_else(|| PyKeyError::new_err(format!("Field {} not found", field_name)))?;
|
||||
|
||||
native_tbl
|
||||
.replace_field_metadata(vec![(field.id as u32, new_metadata)])
|
||||
.await
|
||||
.infer_error()?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn update_field_metadata<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
updates: Vec<Bound<PyDict>>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let updates = updates
|
||||
.iter()
|
||||
.map(|update| {
|
||||
let path: String = update
|
||||
.get_item("path")?
|
||||
.ok_or_else(|| PyValueError::new_err("Missing path"))?
|
||||
.extract()?;
|
||||
let mut field_update = FieldMetadataUpdate::new(path);
|
||||
if let Some(metadata) = update.get_item("metadata")? {
|
||||
let metadata_dict = metadata.cast::<PyDict>()?;
|
||||
for (key, value) in metadata_dict.iter() {
|
||||
let key: String = key.extract()?;
|
||||
if value.is_none() {
|
||||
field_update = field_update.remove(key);
|
||||
} else {
|
||||
field_update = field_update.set(key, value.extract::<String>()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(replace) = update.get_item("replace")?
|
||||
&& replace.extract::<bool>()?
|
||||
{
|
||||
field_update = field_update.replace();
|
||||
}
|
||||
Ok(field_update)
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let result = inner.update_field_metadata(&updates).await.infer_error()?;
|
||||
Ok(UpdateFieldMetadataResult::from(result))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
@@ -1198,8 +1124,6 @@ pub struct MergeInsertParams {
|
||||
when_not_matched_by_source_condition: Option<String>,
|
||||
timeout: Option<std::time::Duration>,
|
||||
use_index: Option<bool>,
|
||||
use_lsm_write: Option<bool>,
|
||||
validate_single_shard: Option<bool>,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -1270,66 +1194,3 @@ impl Tags {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct Branches {
|
||||
inner: LanceDbTable,
|
||||
}
|
||||
|
||||
impl Branches {
|
||||
pub fn new(table: LanceDbTable) -> Self {
|
||||
Self { inner: table }
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Branches {
|
||||
pub fn list(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let res = inner.list_branches().await.infer_error()?;
|
||||
Python::attach(|py| {
|
||||
let py_dict = PyDict::new(py);
|
||||
for (name, contents) in res {
|
||||
let value = PyDict::new(py);
|
||||
value.set_item("parent_branch", contents.parent_branch)?;
|
||||
value.set_item("parent_version", contents.parent_version)?;
|
||||
value.set_item("manifest_size", contents.manifest_size)?;
|
||||
py_dict.set_item(name, value)?;
|
||||
}
|
||||
Ok(py_dict.unbind())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, from_ref=None, from_version=None))]
|
||||
pub fn create(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
from_ref: Option<String>,
|
||||
from_version: Option<u64>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let from = Ref::Version(from_ref, from_version);
|
||||
let table = inner.create_branch(&name, from).await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn checkout(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = inner.checkout_branch(&name).await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn delete(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.delete_branch(&name).await.infer_error()?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "1.95.0"
|
||||
channel = "1.94.0"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.30.1-beta.0"
|
||||
version = "0.29.1-beta.0"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -75,7 +75,12 @@ reqwest = { version = "0.12.0", default-features = false, features = [
|
||||
"stream",
|
||||
], optional = true }
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
uuid = { version = "1.7.0", features = ["v4", "v5"] }
|
||||
# OAuth dependencies (used by remote feature)
|
||||
sha2 = { version = "0.10", optional = true }
|
||||
base64 = { version = "0.22", optional = true }
|
||||
urlencoding = { version = "2", optional = true }
|
||||
open = { version = "5", optional = true }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [
|
||||
@@ -104,7 +109,6 @@ datafusion.workspace = true
|
||||
http-body = "1" # Matching reqwest
|
||||
rstest = "0.23.0"
|
||||
test-log = "0.2"
|
||||
serial_test = "3"
|
||||
|
||||
|
||||
[features]
|
||||
@@ -129,7 +133,7 @@ huggingface = [
|
||||
"lance-namespace-impls/dir-huggingface",
|
||||
]
|
||||
dynamodb = ["lance/dynamodb", "aws"]
|
||||
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
||||
remote = ["dep:reqwest", "dep:http", "dep:sha2", "dep:base64", "dep:urlencoding", "dep:open", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
||||
|
||||
@@ -622,6 +622,8 @@ pub struct ConnectRequest {
|
||||
pub struct ConnectBuilder {
|
||||
request: ConnectRequest,
|
||||
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: Option<crate::remote::oauth::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
@@ -643,6 +645,8 @@ impl ConnectBuilder {
|
||||
session: None,
|
||||
},
|
||||
embedding_registry: None,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -731,6 +735,19 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
|
||||
///
|
||||
/// This creates an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
|
||||
/// from the given config and sets it as the header provider, replacing any
|
||||
/// previously configured header provider or API key.
|
||||
///
|
||||
/// Token acquisition and refresh are handled entirely in Rust.
|
||||
#[cfg(feature = "remote")]
|
||||
pub fn oauth_config(mut self, config: crate::remote::oauth::OAuthConfig) -> Self {
|
||||
self.oauth_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||
self.embedding_registry = Some(registry);
|
||||
@@ -812,7 +829,8 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// The interval at which to check for updates from other processes.
|
||||
/// The interval at which to check for updates from other processes. This
|
||||
/// only affects LanceDB OSS.
|
||||
///
|
||||
/// If left unset, consistency is not checked. For maximum read
|
||||
/// performance, this is the default. For strong consistency, set this to
|
||||
@@ -824,11 +842,8 @@ impl ConnectBuilder {
|
||||
/// This only affects read operations. Write operations are always
|
||||
/// consistent.
|
||||
///
|
||||
/// # Cost
|
||||
///
|
||||
/// Stronger consistency is not free. The smaller the interval, the more
|
||||
/// often each read pays the cost of checking for updates against object
|
||||
/// storage, raising per-read latency and cost.
|
||||
/// LanceDB Cloud uses eventual consistency under the hood, and is not
|
||||
/// currently configurable.
|
||||
pub fn read_consistency_interval(
|
||||
mut self,
|
||||
read_consistency_interval: std::time::Duration,
|
||||
@@ -876,9 +891,29 @@ impl ConnectBuilder {
|
||||
let region = options.region.ok_or_else(|| Error::InvalidInput {
|
||||
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
|
||||
// When OAuth is configured, api_key is not required
|
||||
let api_key = match (&self.oauth_config, &options.api_key) {
|
||||
(Some(_), None) => String::new(),
|
||||
(Some(_), Some(key)) => key.clone(),
|
||||
(None, Some(key)) => key.clone(),
|
||||
(None, None) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let mut client_config = self.request.client_config;
|
||||
|
||||
// Apply OAuth header provider if configured
|
||||
if let Some(oauth_config) = self.oauth_config {
|
||||
let provider = crate::remote::oauth::OAuthHeaderProvider::new(oauth_config)?;
|
||||
client_config.header_provider =
|
||||
Some(Arc::new(provider) as Arc<dyn crate::remote::client::HeaderProvider>);
|
||||
}
|
||||
|
||||
let storage_options = StorageOptions(options.storage_options.clone());
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
@@ -886,9 +921,8 @@ impl ConnectBuilder {
|
||||
&api_key,
|
||||
®ion,
|
||||
options.host_override,
|
||||
self.request.client_config,
|
||||
client_config,
|
||||
storage_options.into(),
|
||||
self.request.read_consistency_interval,
|
||||
)?);
|
||||
Ok(Connection {
|
||||
internal,
|
||||
|
||||
@@ -271,26 +271,15 @@ impl Scannable for WithEmbeddingsScannable {
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Task panicked during embedding computation: {}", e),
|
||||
})??;
|
||||
// Look up columns by name (not position) so the result matches
|
||||
// the output schema even when columns appear in a different
|
||||
// order — e.g. `add_columns` placed a new column after the
|
||||
// embedding column, but the computed batch appends embeddings
|
||||
// at the end. Cast per-column because field metadata (e.g.
|
||||
// nested nullability) may also differ between the embedding
|
||||
// function output and the table.
|
||||
let columns: Vec<ArrayRef> = output_schema
|
||||
.fields()
|
||||
// Cast columns to match the declared output schema. The data is
|
||||
// identical but field metadata (e.g. nested nullability) may
|
||||
// differ between the embedding function output and the table.
|
||||
let columns: Vec<ArrayRef> = result
|
||||
.columns()
|
||||
.iter()
|
||||
.map(|field| {
|
||||
let col = result.column_by_name(field.name()).ok_or_else(|| {
|
||||
Error::InvalidInput {
|
||||
message: format!(
|
||||
"Column '{}' required by the table schema was not present in the input batch",
|
||||
field.name()
|
||||
),
|
||||
}
|
||||
})?;
|
||||
let target_type = field.data_type();
|
||||
.enumerate()
|
||||
.map(|(i, col)| {
|
||||
let target_type = output_schema.field(i).data_type();
|
||||
if col.data_type() == target_type {
|
||||
Ok(col.clone())
|
||||
} else {
|
||||
@@ -975,118 +964,5 @@ mod tests {
|
||||
"Expected EmbeddingFunctionNotFound"
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression test for https://github.com/lancedb/lancedb/issues/3136.
|
||||
///
|
||||
/// When a column is added to the table after the embedding column via
|
||||
/// schema evolution, the table schema becomes
|
||||
/// `[..., embedding, extra]`. The input batch (without the embedding)
|
||||
/// is `[..., extra]`, and `compute_embeddings_for_batch` appends the
|
||||
/// embedding at the end giving `[..., extra, embedding]`. A positional
|
||||
/// cast to the output schema would map `extra` onto `embedding` and
|
||||
/// fail with a CastError. Columns must be matched by name.
|
||||
#[tokio::test]
|
||||
async fn test_with_embeddings_scannable_column_added_after_embedding() {
|
||||
let input_schema = Arc::new(Schema::new(vec![
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new("score", DataType::Float64, true),
|
||||
]));
|
||||
let batch = RecordBatch::try_new(
|
||||
input_schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef,
|
||||
Arc::new(arrow_array::Float64Array::from(vec![1.0, 2.0])) as ArrayRef,
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mock_embedding: Arc<dyn EmbeddingFunction> = Arc::new(MockEmbed::new("mock", 4));
|
||||
let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec"));
|
||||
|
||||
// Table schema: embedding column is BEFORE `score`, as would
|
||||
// happen if `score` was added via `add_columns` after creating
|
||||
// the table with an embedding on `text`.
|
||||
let output_schema = Arc::new(Schema::new(vec![
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"text_vec",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, true)),
|
||||
4,
|
||||
),
|
||||
false,
|
||||
),
|
||||
Field::new("score", DataType::Float64, true),
|
||||
]));
|
||||
|
||||
let mut scannable = WithEmbeddingsScannable::with_schema(
|
||||
Box::new(batch),
|
||||
vec![(embedding_def, mock_embedding)],
|
||||
output_schema.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let stream = scannable.scan_as_stream();
|
||||
let results: Vec<RecordBatch> = stream.try_collect().await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
|
||||
let result_batch = &results[0];
|
||||
assert_eq!(result_batch.schema(), output_schema);
|
||||
assert_eq!(result_batch.num_rows(), 2);
|
||||
// Position 1 must actually hold the FixedSizeList embedding —
|
||||
// not the score column reinterpreted by a permissive cast.
|
||||
let embedding = result_batch
|
||||
.column(1)
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::FixedSizeListArray>()
|
||||
.expect("position 1 should be a FixedSizeList embedding");
|
||||
assert_eq!(embedding.value_length(), 4);
|
||||
assert_eq!(embedding.null_count(), 0);
|
||||
}
|
||||
|
||||
/// If the input batch is missing a non-embedding column required by
|
||||
/// the table schema, we should return a clear error rather than
|
||||
/// silently producing a malformed batch.
|
||||
#[tokio::test]
|
||||
async fn test_with_embeddings_scannable_missing_required_column() {
|
||||
let input_schema =
|
||||
Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
|
||||
let batch = RecordBatch::try_new(
|
||||
input_schema,
|
||||
vec![Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mock_embedding: Arc<dyn EmbeddingFunction> = Arc::new(MockEmbed::new("mock", 4));
|
||||
let embedding_def = EmbeddingDefinition::new("text", "mock", Some("text_vec"));
|
||||
|
||||
let output_schema = Arc::new(Schema::new(vec![
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"text_vec",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, true)),
|
||||
4,
|
||||
),
|
||||
false,
|
||||
),
|
||||
Field::new("score", DataType::Float64, true),
|
||||
]));
|
||||
|
||||
let mut scannable = WithEmbeddingsScannable::with_schema(
|
||||
Box::new(batch),
|
||||
vec![(embedding_def, mock_embedding)],
|
||||
output_schema,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let stream = scannable.scan_as_stream();
|
||||
let results: Result<Vec<RecordBatch>> = stream.try_collect().await;
|
||||
let err = results.expect_err("expected an error");
|
||||
assert!(
|
||||
matches!(&err, Error::InvalidInput { message } if message.contains("score")),
|
||||
"expected InvalidInput about missing 'score' column, got: {err:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -464,9 +464,11 @@ mod tests {
|
||||
let mut iter = ids.into_iter().map(|o| o.unwrap());
|
||||
while let Some(first) = iter.next() {
|
||||
let rows_left_in_clump = if first == 4470 { 19 } else { 29 };
|
||||
for expected_next in (first + 1)..=(first + rows_left_in_clump) {
|
||||
let mut expected_next = first + 1;
|
||||
for _ in 0..rows_left_in_clump {
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next, expected_next);
|
||||
expected_next += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,12 +23,17 @@ impl VectorIndex {
|
||||
.fields
|
||||
.iter()
|
||||
.map(|field_id| {
|
||||
manifest.schema.field_path(*field_id).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"field {field_id} of index {} must exist in schema",
|
||||
index.name
|
||||
)
|
||||
})
|
||||
manifest
|
||||
.schema
|
||||
.field_by_id(*field_id)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"field {field_id} of index {} must exist in schema",
|
||||
index.name
|
||||
)
|
||||
})
|
||||
.name
|
||||
.clone()
|
||||
})
|
||||
.collect();
|
||||
Self {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod db;
|
||||
pub mod oauth;
|
||||
mod retry;
|
||||
pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
@@ -20,3 +21,4 @@ const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
||||
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
||||
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};
|
||||
|
||||
@@ -245,9 +245,6 @@ pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
|
||||
pub(crate) sender: S,
|
||||
pub(crate) id_delimiter: String,
|
||||
pub(crate) header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||
/// Connection-level read consistency interval. Drives the
|
||||
/// `x-lancedb-min-timestamp` freshness header sent on read requests.
|
||||
pub(crate) read_consistency_interval: Option<Duration>,
|
||||
}
|
||||
|
||||
impl<S: HttpSend> std::fmt::Debug for RestfulLanceDbClient<S> {
|
||||
@@ -341,7 +338,6 @@ impl RestfulLanceDbClient<Sender> {
|
||||
host_override: Option<String>,
|
||||
default_headers: HeaderMap,
|
||||
client_config: ClientConfig,
|
||||
read_consistency_interval: Option<Duration>,
|
||||
) -> Result<Self> {
|
||||
// Get the timeouts
|
||||
let timeout =
|
||||
@@ -439,7 +435,6 @@ impl RestfulLanceDbClient<Sender> {
|
||||
.clone()
|
||||
.unwrap_or("$".to_string()),
|
||||
header_provider: client_config.header_provider,
|
||||
read_consistency_interval,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -845,16 +840,6 @@ pub mod test_utils {
|
||||
pub fn client_with_handler<T>(
|
||||
handler: impl Fn(reqwest::Request) -> http::response::Response<T> + Send + Sync + 'static,
|
||||
) -> RestfulLanceDbClient<MockSender>
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
client_with_handler_and_interval(handler, None)
|
||||
}
|
||||
|
||||
pub fn client_with_handler_and_interval<T>(
|
||||
handler: impl Fn(reqwest::Request) -> http::response::Response<T> + Send + Sync + 'static,
|
||||
read_consistency_interval: Option<Duration>,
|
||||
) -> RestfulLanceDbClient<MockSender>
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
@@ -872,7 +857,6 @@ pub mod test_utils {
|
||||
},
|
||||
id_delimiter: "$".to_string(),
|
||||
header_provider: None,
|
||||
read_consistency_interval,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -897,7 +881,6 @@ pub mod test_utils {
|
||||
},
|
||||
id_delimiter: config.id_delimiter.unwrap_or_else(|| "$".to_string()),
|
||||
header_provider: config.header_provider,
|
||||
read_consistency_interval: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -905,18 +888,8 @@ pub mod test_utils {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serial_test::serial;
|
||||
use std::time::Duration;
|
||||
|
||||
// Serializes the env-var-mutating tests below: cargo test runs tests in
|
||||
// parallel, but several of these tests read and write the same process-
|
||||
// global env vars (`LANCEDB_USER_ID*`), so they would race without this.
|
||||
static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||
|
||||
fn lock_env() -> std::sync::MutexGuard<'static, ()> {
|
||||
ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeout_config_default() {
|
||||
let config = TimeoutConfig::default();
|
||||
@@ -1073,7 +1046,6 @@ mod tests {
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
read_consistency_interval: None,
|
||||
};
|
||||
|
||||
// Apply dynamic headers
|
||||
@@ -1109,7 +1081,6 @@ mod tests {
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
read_consistency_interval: None,
|
||||
};
|
||||
|
||||
// Apply dynamic headers
|
||||
@@ -1147,7 +1118,6 @@ mod tests {
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
read_consistency_interval: None,
|
||||
};
|
||||
|
||||
// Header provider errors should fail the request
|
||||
@@ -1173,9 +1143,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(user_id_env)]
|
||||
fn test_resolve_user_id_none() {
|
||||
let _guard = lock_env();
|
||||
let config = ClientConfig::default();
|
||||
// Clear env vars that might be set from other tests
|
||||
// SAFETY: This is only called in tests
|
||||
@@ -1187,9 +1155,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(user_id_env)]
|
||||
fn test_resolve_user_id_from_env() {
|
||||
let _guard = lock_env();
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::set_var("LANCEDB_USER_ID", "env-user-id");
|
||||
@@ -1203,9 +1169,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(user_id_env)]
|
||||
fn test_resolve_user_id_from_env_key() {
|
||||
let _guard = lock_env();
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::remove_var("LANCEDB_USER_ID");
|
||||
@@ -1225,9 +1189,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(user_id_env)]
|
||||
fn test_resolve_user_id_direct_takes_precedence() {
|
||||
let _guard = lock_env();
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::set_var("LANCEDB_USER_ID", "env-user-id");
|
||||
@@ -1244,9 +1206,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(user_id_env)]
|
||||
fn test_resolve_user_id_empty_env_ignored() {
|
||||
let _guard = lock_env();
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::set_var("LANCEDB_USER_ID", "");
|
||||
|
||||
@@ -206,7 +206,6 @@ impl RemoteDatabase {
|
||||
host_override: Option<String>,
|
||||
client_config: ClientConfig,
|
||||
options: RemoteOptions,
|
||||
read_consistency_interval: Option<std::time::Duration>,
|
||||
) -> Result<Self> {
|
||||
let parsed = super::client::parse_db_url(uri)?;
|
||||
let header_map = RestfulLanceDbClient::<Sender>::default_headers(
|
||||
@@ -234,7 +233,6 @@ impl RemoteDatabase {
|
||||
host_override,
|
||||
header_map,
|
||||
client_config.clone(),
|
||||
read_consistency_interval,
|
||||
)?;
|
||||
|
||||
let table_cache = Cache::builder()
|
||||
|
||||
906
rust/lancedb/src/remote/oauth.rs
Normal file
906
rust/lancedb/src/remote/oauth.rs
Normal file
@@ -0,0 +1,906 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use log::{debug, info, warn};
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::remote::client::HeaderProvider;
|
||||
|
||||
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
|
||||
const DEFAULT_CALLBACK_PORT: u16 = 8400;
|
||||
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
|
||||
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
|
||||
|
||||
/// OAuth authentication flow configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OAuthFlow {
|
||||
/// Client Credentials grant (service-to-service / M2M).
|
||||
/// Requires `client_secret` in [`OAuthConfig`].
|
||||
ClientCredentials,
|
||||
|
||||
/// Authorization Code with PKCE (interactive browser-based auth).
|
||||
AuthorizationCodePKCE {
|
||||
/// Redirect URI (default: `http://localhost:{callback_port}/callback`)
|
||||
redirect_uri: Option<String>,
|
||||
/// Port for the local HTTP callback server (default: 8400)
|
||||
callback_port: Option<u16>,
|
||||
},
|
||||
|
||||
/// Device Code grant (CLI / headless environments).
|
||||
/// Displays a verification URL and user code for out-of-band authentication.
|
||||
DeviceCode,
|
||||
|
||||
/// Azure Managed Identity via IMDS.
|
||||
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
|
||||
AzureManagedIdentity {
|
||||
/// Client ID for user-assigned managed identity.
|
||||
/// Omit for system-assigned managed identity.
|
||||
client_id: Option<String>,
|
||||
},
|
||||
|
||||
/// Workload Identity Federation.
|
||||
/// Exchanges a platform token (K8s service account, GitHub OIDC) for an IdP token.
|
||||
WorkloadIdentity {
|
||||
/// Path to the federated token file
|
||||
/// (e.g. `AZURE_FEDERATED_TOKEN_FILE`).
|
||||
token_file: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// OAuth configuration for LanceDB authentication.
|
||||
///
|
||||
/// All token acquisition and refresh is handled in the Rust layer.
|
||||
/// Python and TypeScript bindings expose this as a plain config object.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OAuthConfig {
|
||||
/// OIDC issuer URL or OAuth authority URL.
|
||||
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
pub issuer_url: String,
|
||||
|
||||
/// Application / Client ID.
|
||||
pub client_id: String,
|
||||
|
||||
/// Client secret (required for `ClientCredentials`, optional for others).
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
/// OAuth scopes to request.
|
||||
/// For Azure: `["api://{app_id}/.default"]`
|
||||
pub scopes: Vec<String>,
|
||||
|
||||
/// Authentication flow to use.
|
||||
pub flow: OAuthFlow,
|
||||
|
||||
/// Seconds before token expiry to trigger proactive refresh (default: 300).
|
||||
pub refresh_buffer_secs: Option<u64>,
|
||||
}
|
||||
|
||||
// -- OIDC Discovery --
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OidcDiscovery {
|
||||
token_endpoint: String,
|
||||
authorization_endpoint: Option<String>,
|
||||
device_authorization_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
// -- Token Response --
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
/// Token lifetime in seconds.
|
||||
/// Some providers (Azure IMDS) return this as a string, so we accept both.
|
||||
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
|
||||
expires_in: Option<u64>,
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
fn deserialize_optional_u64_or_string<'de, D>(
|
||||
deserializer: D,
|
||||
) -> std::result::Result<Option<u64>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de;
|
||||
|
||||
struct U64OrString;
|
||||
impl<'de> de::Visitor<'de> for U64OrString {
|
||||
type Value = Option<u64>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a u64, a numeric string, or null")
|
||||
}
|
||||
|
||||
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
|
||||
Ok(Some(v))
|
||||
}
|
||||
|
||||
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
|
||||
Ok(Some(v as u64))
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
|
||||
v.parse::<u64>().map(Some).map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(U64OrString)
|
||||
}
|
||||
|
||||
// -- Device Code Response --
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default)]
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
interval: Option<u64>,
|
||||
}
|
||||
|
||||
// -- Internal Token State --
|
||||
|
||||
struct TokenState {
|
||||
access_token: Option<String>,
|
||||
refresh_token: Option<String>,
|
||||
expires_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl TokenState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
access_token: None,
|
||||
refresh_token: None,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_expired(&self, buffer: Duration) -> bool {
|
||||
match (self.access_token.as_ref(), self.expires_at) {
|
||||
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
|
||||
(None, _) => true,
|
||||
(Some(_), None) => false, // no expiry info, assume valid
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, resp: &TokenResponse) {
|
||||
self.access_token = Some(resp.access_token.clone());
|
||||
if resp.refresh_token.is_some() {
|
||||
self.refresh_token = resp.refresh_token.clone();
|
||||
}
|
||||
self.expires_at = resp
|
||||
.expires_in
|
||||
.map(|secs| Instant::now() + Duration::from_secs(secs));
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth header provider that manages the full token lifecycle.
|
||||
///
|
||||
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
|
||||
/// headers into every LanceDB request, with automatic token refresh.
|
||||
pub struct OAuthHeaderProvider {
|
||||
config: OAuthConfig,
|
||||
http_client: Client,
|
||||
token_state: Arc<RwLock<TokenState>>,
|
||||
/// Cached OIDC discovery document
|
||||
discovery: Arc<RwLock<Option<OidcDiscovery>>>,
|
||||
refresh_buffer: Duration,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthHeaderProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthHeaderProvider")
|
||||
.field("issuer_url", &self.config.issuer_url)
|
||||
.field("client_id", &self.config.client_id)
|
||||
.field("flow", &self.config.flow)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthHeaderProvider {
|
||||
/// Create a new OAuth header provider from configuration.
|
||||
pub fn new(config: OAuthConfig) -> Result<Self> {
|
||||
// Validate config upfront
|
||||
if matches!(config.flow, OAuthFlow::ClientCredentials) && config.client_secret.is_none() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "client_secret is required for ClientCredentials flow".to_string(),
|
||||
});
|
||||
}
|
||||
if config.scopes.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "At least one OAuth scope is required".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create HTTP client for OAuth: {e}"),
|
||||
})?;
|
||||
|
||||
let refresh_buffer = Duration::from_secs(
|
||||
config
|
||||
.refresh_buffer_secs
|
||||
.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS),
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
http_client,
|
||||
token_state: Arc::new(RwLock::new(TokenState::new())),
|
||||
discovery: Arc::new(RwLock::new(None)),
|
||||
refresh_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid access token, refreshing if necessary.
|
||||
async fn get_valid_token(&self) -> Result<String> {
|
||||
// Fast path: check if current token is still valid
|
||||
{
|
||||
let state = self.token_state.read().await;
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: acquire or refresh token
|
||||
let mut state = self.token_state.write().await;
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
let uses_refresh_token = !matches!(
|
||||
self.config.flow,
|
||||
OAuthFlow::ClientCredentials
|
||||
| OAuthFlow::AzureManagedIdentity { .. }
|
||||
| OAuthFlow::WorkloadIdentity { .. }
|
||||
);
|
||||
|
||||
let resp = if let Some(ref refresh_token) = state.refresh_token
|
||||
&& uses_refresh_token
|
||||
{
|
||||
debug!("Refreshing OAuth token using refresh_token");
|
||||
self.refresh_with_token(refresh_token).await?
|
||||
} else {
|
||||
debug!("Acquiring new OAuth token via {:?} flow", self.config.flow);
|
||||
self.acquire_token().await?
|
||||
};
|
||||
|
||||
state.update(&resp);
|
||||
Ok(resp.access_token)
|
||||
}
|
||||
|
||||
/// Acquire a new token using the configured flow.
|
||||
async fn acquire_token(&self) -> Result<TokenResponse> {
|
||||
match &self.config.flow {
|
||||
OAuthFlow::ClientCredentials => self.acquire_client_credentials().await,
|
||||
OAuthFlow::AuthorizationCodePKCE {
|
||||
redirect_uri,
|
||||
callback_port,
|
||||
} => {
|
||||
self.acquire_authorization_code_pkce(
|
||||
redirect_uri.as_deref(),
|
||||
callback_port.unwrap_or(DEFAULT_CALLBACK_PORT),
|
||||
)
|
||||
.await
|
||||
}
|
||||
OAuthFlow::DeviceCode => self.acquire_device_code().await,
|
||||
OAuthFlow::AzureManagedIdentity { client_id } => {
|
||||
self.acquire_managed_identity(client_id.as_deref()).await
|
||||
}
|
||||
OAuthFlow::WorkloadIdentity { token_file } => {
|
||||
self.acquire_workload_identity(token_file).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- OIDC Discovery --
|
||||
|
||||
async fn get_discovery(&self) -> Result<OidcDiscovery> {
|
||||
{
|
||||
let cached = self.discovery.read().await;
|
||||
if let Some(ref disc) = *cached {
|
||||
return Ok(OidcDiscovery {
|
||||
token_endpoint: disc.token_endpoint.clone(),
|
||||
authorization_endpoint: disc.authorization_endpoint.clone(),
|
||||
device_authorization_endpoint: disc.device_authorization_endpoint.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = self.discovery.write().await;
|
||||
// Double-check
|
||||
if let Some(ref disc) = *cache {
|
||||
return Ok(OidcDiscovery {
|
||||
token_endpoint: disc.token_endpoint.clone(),
|
||||
authorization_endpoint: disc.authorization_endpoint.clone(),
|
||||
device_authorization_endpoint: disc.device_authorization_endpoint.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let discovery_url = format!(
|
||||
"{}/.well-known/openid-configuration",
|
||||
self.config.issuer_url.trim_end_matches('/')
|
||||
);
|
||||
|
||||
debug!("Fetching OIDC discovery from {}", discovery_url);
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to fetch OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"OIDC discovery failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
let result = OidcDiscovery {
|
||||
token_endpoint: disc.token_endpoint.clone(),
|
||||
authorization_endpoint: disc.authorization_endpoint.clone(),
|
||||
device_authorization_endpoint: disc.device_authorization_endpoint.clone(),
|
||||
};
|
||||
|
||||
*cache = Some(disc);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get_token_endpoint_from_issuer(&self) -> String {
|
||||
// Derive Azure v2.0 token endpoint from issuer URL
|
||||
// issuer: https://login.microsoftonline.com/{tenant}/v2.0
|
||||
// token: https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
|
||||
let base = self.config.issuer_url.trim_end_matches("/v2.0");
|
||||
format!("{base}/oauth2/v2.0/token")
|
||||
}
|
||||
|
||||
async fn get_token_endpoint(&self) -> Result<String> {
|
||||
match self.get_discovery().await {
|
||||
Ok(disc) => Ok(disc.token_endpoint),
|
||||
Err(e) => {
|
||||
warn!("OIDC discovery failed, falling back to derived endpoint: {e}");
|
||||
Ok(self.get_token_endpoint_from_issuer())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scopes_string(&self) -> String {
|
||||
self.config.scopes.join(" ")
|
||||
}
|
||||
|
||||
// -- Client Credentials Flow --
|
||||
|
||||
async fn acquire_client_credentials(&self) -> Result<TokenResponse> {
|
||||
let client_secret = self
|
||||
.config
|
||||
.client_secret
|
||||
.as_ref()
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: "client_secret is required for ClientCredentials flow".to_string(),
|
||||
})?;
|
||||
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
|
||||
let params = [
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", &self.config.client_id),
|
||||
("client_secret", client_secret),
|
||||
("scope", &self.scopes_string()),
|
||||
];
|
||||
|
||||
self.post_token_request(&token_endpoint, ¶ms).await
|
||||
}
|
||||
|
||||
// -- Authorization Code + PKCE Flow --
|
||||
|
||||
async fn acquire_authorization_code_pkce(
|
||||
&self,
|
||||
redirect_uri: Option<&str>,
|
||||
callback_port: u16,
|
||||
) -> Result<TokenResponse> {
|
||||
use rand::Rng;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let discovery = self.get_discovery().await?;
|
||||
let auth_endpoint = discovery.authorization_endpoint.ok_or(Error::Runtime {
|
||||
message: "OIDC discovery did not provide authorization_endpoint".to_string(),
|
||||
})?;
|
||||
|
||||
let default_redirect = format!("http://localhost:{callback_port}/callback");
|
||||
let redirect = redirect_uri.unwrap_or(&default_redirect);
|
||||
|
||||
// Generate PKCE code verifier and challenge (S256)
|
||||
const PKCE_CHARSET: &[u8] =
|
||||
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
|
||||
let code_verifier: String = {
|
||||
let mut rng = rand::rng();
|
||||
(0..128)
|
||||
.map(|_| {
|
||||
let idx = rng.random_range(0..PKCE_CHARSET.len());
|
||||
PKCE_CHARSET[idx] as char
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
let code_challenge = {
|
||||
use sha2::{Digest, Sha256};
|
||||
let hash = Sha256::digest(code_verifier.as_bytes());
|
||||
base64_url_encode(&hash)
|
||||
};
|
||||
|
||||
let state: String = {
|
||||
let mut rng = rand::rng();
|
||||
(0..32)
|
||||
.map(|_| {
|
||||
let idx = rng.random_range(0..16u8);
|
||||
b"0123456789abcdef"[idx as usize] as char
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Build authorization URL
|
||||
let auth_url = format!(
|
||||
"{auth_endpoint}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={state}",
|
||||
urlencoding::encode(&self.config.client_id),
|
||||
urlencoding::encode(redirect),
|
||||
urlencoding::encode(&self.scopes_string()),
|
||||
urlencoding::encode(&code_challenge),
|
||||
);
|
||||
|
||||
info!("Opening browser for OAuth login...");
|
||||
info!("If the browser doesn't open, visit: {auth_url}");
|
||||
|
||||
// Try to open browser
|
||||
let _ = open::that(&auth_url);
|
||||
|
||||
// Start local callback server
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{callback_port}"))
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to bind callback server on port {callback_port}: {e}"),
|
||||
})?;
|
||||
|
||||
info!("Waiting for OAuth callback on port {callback_port}...");
|
||||
|
||||
let (mut stream, _) = listener.accept().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to accept callback connection: {e}"),
|
||||
})?;
|
||||
|
||||
// Read the HTTP request
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buf)
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read callback request: {e}"),
|
||||
})?;
|
||||
let request_str = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// Extract authorization code from query params
|
||||
let code = extract_query_param(&request_str, "code").ok_or(Error::Runtime {
|
||||
message: "No authorization code in callback".to_string(),
|
||||
})?;
|
||||
|
||||
let returned_state = extract_query_param(&request_str, "state");
|
||||
if returned_state.as_deref() != Some(&state) {
|
||||
return Err(Error::Runtime {
|
||||
message: "OAuth state mismatch — possible CSRF attack".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Send success response to browser
|
||||
let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>";
|
||||
let _ = stream.write_all(response.as_bytes()).await;
|
||||
|
||||
// Exchange code for tokens
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
let mut params = vec![
|
||||
("grant_type", "authorization_code"),
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
("code", &code),
|
||||
("redirect_uri", redirect),
|
||||
("code_verifier", &code_verifier),
|
||||
];
|
||||
if let Some(ref secret) = self.config.client_secret {
|
||||
params.push(("client_secret", secret));
|
||||
}
|
||||
|
||||
self.post_token_request(&token_endpoint, ¶ms).await
|
||||
}
|
||||
|
||||
// -- Device Code Flow --
|
||||
|
||||
async fn acquire_device_code(&self) -> Result<TokenResponse> {
|
||||
let discovery = self.get_discovery().await?;
|
||||
let device_endpoint = discovery
|
||||
.device_authorization_endpoint
|
||||
.ok_or(Error::Runtime {
|
||||
message: "OIDC discovery did not provide device_authorization_endpoint".to_string(),
|
||||
})?;
|
||||
|
||||
let params = [
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
("scope", &self.scopes_string()),
|
||||
];
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.post(&device_endpoint)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Device code request failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Device code request failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let device_resp: DeviceCodeResponse = resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse device code response: {e}"),
|
||||
})?;
|
||||
|
||||
// Display instructions to user
|
||||
info!(
|
||||
"To sign in, visit {} and enter code: {}",
|
||||
device_resp.verification_uri, device_resp.user_code
|
||||
);
|
||||
if let Some(ref uri) = device_resp.verification_uri_complete {
|
||||
info!("Or visit: {uri}");
|
||||
}
|
||||
|
||||
// Poll token endpoint
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
let poll_interval = Duration::from_secs(device_resp.interval.unwrap_or(5));
|
||||
let deadline = Instant::now() + Duration::from_secs(device_resp.expires_in);
|
||||
|
||||
loop {
|
||||
if Instant::now() >= deadline {
|
||||
return Err(Error::Runtime {
|
||||
message: "Device code flow timed out waiting for user authentication"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let poll_params = [
|
||||
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
("device_code", &device_resp.device_code),
|
||||
];
|
||||
|
||||
let poll_resp = self
|
||||
.http_client
|
||||
.post(&token_endpoint)
|
||||
.form(&poll_params)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Device code poll failed: {e}"),
|
||||
})?;
|
||||
|
||||
if poll_resp.status().is_success() {
|
||||
return poll_resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse token response: {e}"),
|
||||
});
|
||||
}
|
||||
|
||||
// Check for pending / slow_down errors
|
||||
let body = poll_resp.text().await.unwrap_or_default();
|
||||
if body.contains("authorization_pending") {
|
||||
continue;
|
||||
}
|
||||
if body.contains("slow_down") {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(Error::Runtime {
|
||||
message: format!("Device code poll failed: {body}"),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// -- Azure Managed Identity Flow --
|
||||
|
||||
async fn acquire_managed_identity(&self, mi_client_id: Option<&str>) -> Result<TokenResponse> {
|
||||
let resource = self.scopes_string().replace("/.default", "");
|
||||
|
||||
let mut url = format!(
|
||||
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
|
||||
urlencoding::encode(&resource),
|
||||
);
|
||||
if let Some(cid) = mi_client_id {
|
||||
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Metadata", "true")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Azure IMDS request failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Azure IMDS returned status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse IMDS token response: {e}"),
|
||||
})
|
||||
}
|
||||
|
||||
// -- Workload Identity Federation Flow --
|
||||
|
||||
async fn acquire_workload_identity(&self, token_file: &str) -> Result<TokenResponse> {
|
||||
let federated_token =
|
||||
tokio::fs::read_to_string(token_file)
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read federated token file '{token_file}': {e}"),
|
||||
})?;
|
||||
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
|
||||
let params = [
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
(
|
||||
"client_assertion_type",
|
||||
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
),
|
||||
("client_assertion", federated_token.trim()),
|
||||
("scope", &self.scopes_string()),
|
||||
];
|
||||
|
||||
self.post_token_request(&token_endpoint, ¶ms).await
|
||||
}
|
||||
|
||||
// -- Refresh Token Flow --
|
||||
|
||||
async fn refresh_with_token(&self, refresh_token: &str) -> Result<TokenResponse> {
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
|
||||
let mut params = vec![
|
||||
("grant_type", "refresh_token"),
|
||||
("client_id", self.config.client_id.as_str()),
|
||||
("refresh_token", refresh_token),
|
||||
];
|
||||
if let Some(ref secret) = self.config.client_secret {
|
||||
params.push(("client_secret", secret.as_str()));
|
||||
}
|
||||
|
||||
self.post_token_request(&token_endpoint, ¶ms).await
|
||||
}
|
||||
|
||||
// -- Shared Helpers --
|
||||
|
||||
async fn post_token_request(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
params: &[(&str, &str)],
|
||||
) -> Result<TokenResponse> {
|
||||
let resp = self
|
||||
.http_client
|
||||
.post(endpoint)
|
||||
.form(params)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Token request to {endpoint} failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Token request failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse token response: {e}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for OAuthHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
let token = self.get_valid_token().await?;
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
format!("Bearer {token}"),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
// -- Utility functions --
|
||||
|
||||
fn base64_url_encode(input: &[u8]) -> String {
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(input)
|
||||
}
|
||||
|
||||
/// Extract a query parameter value from a raw HTTP GET request line.
|
||||
fn extract_query_param(request: &str, param: &str) -> Option<String> {
|
||||
let first_line = request.lines().next()?;
|
||||
let path = first_line.split_whitespace().nth(1)?;
|
||||
let query = path.split('?').nth(1)?;
|
||||
for pair in query.split('&') {
|
||||
let mut kv = pair.splitn(2, '=');
|
||||
if let (Some(key), Some(value)) = (kv.next(), kv.next())
|
||||
&& key == param
|
||||
{
|
||||
return Some(urlencoding::decode(value).ok()?.into_owned());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_query_param() {
|
||||
let request = "GET /callback?code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n";
|
||||
assert_eq!(
|
||||
extract_query_param(request, "code"),
|
||||
Some("abc123".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
extract_query_param(request, "state"),
|
||||
Some("xyz".to_string())
|
||||
);
|
||||
assert_eq!(extract_query_param(request, "missing"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_query_param_encoded() {
|
||||
let request = "GET /callback?code=abc%20123&state=x%26y HTTP/1.1\r\n";
|
||||
assert_eq!(
|
||||
extract_query_param(request, "code"),
|
||||
Some("abc 123".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
extract_query_param(request, "state"),
|
||||
Some("x&y".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_state_expiry() {
|
||||
let mut state = TokenState::new();
|
||||
assert!(state.is_expired(Duration::from_secs(0)));
|
||||
|
||||
state.access_token = Some("tok".to_string());
|
||||
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
|
||||
assert!(!state.is_expired(Duration::from_secs(300)));
|
||||
assert!(state.is_expired(Duration::from_secs(601)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_url_encode() {
|
||||
let input = b"hello world";
|
||||
let encoded = base64_url_encode(input);
|
||||
assert!(!encoded.contains('+'));
|
||||
assert!(!encoded.contains('/'));
|
||||
assert!(!encoded.contains('='));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scopes_string() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope1".to_string(), "scope2".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
assert_eq!(provider.scopes_string(), "scope1 scope2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_endpoint_derivation() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/my-tenant/v2.0".to_string(),
|
||||
client_id: "id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec!["api://test/.default".to_string()],
|
||||
flow: OAuthFlow::DeviceCode,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
assert_eq!(
|
||||
provider.get_token_endpoint_from_issuer(),
|
||||
"https://login.microsoftonline.com/my-tenant/oauth2/v2.0/token"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_credentials_requires_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_scopes_rejected() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec![],
|
||||
flow: OAuthFlow::DeviceCode,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user