Compare commits

..

2 Commits

Author SHA1 Message Date
BubbleCal
929c683a6e Handle version number refs 2025-12-22 16:53:16 +08:00
lancedb automation
e2794d1a29 chore: update lance dependency to v2.0.0-beta.3 2025-12-19 22:04:10 +00:00
81 changed files with 549 additions and 2528 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.24.1"
current_version = "0.23.1-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -3,7 +3,7 @@ name: build-linux-wheel
description: "Build a manylinux wheel for lance"
inputs:
python-minor-version:
description: "10, 11, 12, 13"
description: "8, 9, 10, 11, 12"
required: true
args:
description: "--release"

View File

@@ -3,7 +3,7 @@ name: build_wheel
description: "Build a lance wheel"
inputs:
python-minor-version:
description: "10, 11, 12, 13"
description: "8, 9, 10, 11"
required: true
args:
description: "--release"

View File

@@ -3,7 +3,7 @@ name: build_wheel
description: "Build a lance wheel"
inputs:
python-minor-version:
description: "10, 11, 12, 13, 14"
description: "8, 9, 10, 11"
required: true
args:
description: "--release"

View File

@@ -75,13 +75,6 @@ jobs:
VERSION="${VERSION#v}"
BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
# Use "chore" for beta/rc versions, "feat" for stable releases
if [[ "${VERSION}" == *beta* ]] || [[ "${VERSION}" == *rc* ]]; then
COMMIT_TYPE="chore"
else
COMMIT_TYPE="feat"
fi
cat <<EOF >/tmp/codex-prompt.txt
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
@@ -91,10 +84,10 @@ jobs:
3. After clippy succeeds, run "cargo fmt --all" to format the workspace.
4. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
5. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
6. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
6. Stage all relevant files with "git add -A". Commit using the message "chore: update lance dependency to v${VERSION}".
7. Push the branch to origin. If the branch already exists, force-push your changes.
8. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
9. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
9. Create a pull request targeting "main" with title "chore: update lance dependency to v${VERSION}". In the body, summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}).
10. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
Constraints:

View File

@@ -41,7 +41,7 @@ jobs:
sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"

View File

@@ -44,12 +44,12 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: 3.8
- uses: ./.github/workflows/build_linux_wheel
with:
python-minor-version: 10
python-minor-version: 8
args: "--release --strip ${{ matrix.config.extra_args }}"
arm-build: ${{ matrix.config.platform == 'aarch64' }}
manylinux: ${{ matrix.config.manylinux }}
@@ -74,12 +74,12 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.13"
python-version: 3.12
- uses: ./.github/workflows/build_mac_wheel
with:
python-minor-version: 10
python-minor-version: 8
args: "--release --strip --target ${{ matrix.config.target }} --features fp16kernels"
- uses: ./.github/workflows/upload_wheel
if: startsWith(github.ref, 'refs/tags/python-v')
@@ -95,12 +95,12 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.13"
python-version: 3.12
- uses: ./.github/workflows/build_windows_wheel
with:
python-minor-version: 10
python-minor-version: 8
args: "--release --strip"
vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }}
- uses: ./.github/workflows/upload_wheel

View File

@@ -36,9 +36,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: "3.13"
python-version: "3.12"
- name: Install ruff
run: |
pip install ruff==0.9.9
@@ -61,9 +61,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: "3.13"
python-version: "3.12"
- name: Install protobuf compiler
run: |
sudo apt update
@@ -90,9 +90,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: "3.13"
python-version: "3.12"
cache: "pip"
- name: Install protobuf
run: |
@@ -110,7 +110,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
python-minor-version: ["10", "13"]
python-minor-version: ["9", "12"]
runs-on: "ubuntu-24.04"
defaults:
run:
@@ -126,7 +126,7 @@ jobs:
sudo apt update
sudo apt install -y protobuf-compiler
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: 3.${{ matrix.python-minor-version }}
- uses: ./.github/workflows/build_linux_wheel
@@ -156,9 +156,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: "3.13"
python-version: "3.12"
- uses: ./.github/workflows/build_mac_wheel
with:
args: --profile ci
@@ -185,9 +185,9 @@ jobs:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: "3.13"
python-version: "3.12"
- uses: ./.github/workflows/build_windows_wheel
with:
args: --profile ci
@@ -212,9 +212,9 @@ jobs:
sudo apt update
sudo apt install -y protobuf-compiler
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: 3.9
- name: Install lancedb
run: |
pip install "pydantic<2"

View File

@@ -48,8 +48,6 @@ jobs:
run: cargo fmt --all -- --check
- name: Run clippy
run: cargo clippy --profile ci --workspace --tests --all-features -- -D warnings
- name: Run clippy (without remote feature)
run: cargo clippy --profile ci --workspace --tests -- -D warnings
build-no-lock:
runs-on: ubuntu-24.04
@@ -169,13 +167,13 @@ jobs:
- name: Build
run: |
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo build --profile ci --features aws,remote --tests --locked --target ${{ matrix.target }}
cargo build --profile ci --features remote --tests --locked --target ${{ matrix.target }}
- name: Run tests
# Can only run tests when target matches host
if: ${{ matrix.target == 'x86_64-pc-windows-msvc' }}
run: |
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo test --profile ci --features aws,remote --locked
cargo test --profile ci --features remote --locked
msrv:
# Check the minimum supported Rust version
@@ -183,7 +181,7 @@ jobs:
runs-on: ubuntu-24.04
strategy:
matrix:
msrv: ["1.88.0"] # This should match up with rust-version in Cargo.toml
msrv: ["1.78.0"] # This should match up with rust-version in Cargo.toml
env:
# Need up-to-date compilers for kernels
CC: clang-18
@@ -214,6 +212,4 @@ jobs:
cargo update -p aws-sdk-sts --precise 1.51.0
cargo update -p home --precise 0.5.9
- name: cargo +${{ matrix.msrv }} check
env:
RUSTUP_TOOLCHAIN: ${{ matrix.msrv }}
run: cargo check --profile ci --workspace --tests --benches --all-features

94
Cargo.lock generated
View File

@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
version = 3
[[package]]
name = "adler2"
@@ -3141,8 +3141,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow-array",
"rand 0.9.2",
@@ -4261,7 +4261,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5"
dependencies = [
"equivalent",
"hashbrown 0.16.0",
"hashbrown 0.15.5",
"serde",
"serde_core",
]
@@ -4478,8 +4478,8 @@ dependencies = [
[[package]]
name = "lance"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"arrow-arith",
@@ -4544,13 +4544,14 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow-array",
"arrow-buffer",
"arrow-cast",
"arrow-data",
"arrow-ord",
"arrow-schema",
"arrow-select",
"bytes",
@@ -4563,8 +4564,8 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrayref",
"paste",
@@ -4573,8 +4574,8 @@ dependencies = [
[[package]]
name = "lance-core"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4587,6 +4588,7 @@ dependencies = [
"datafusion-sql",
"deepsize",
"futures",
"itertools 0.13.0",
"lance-arrow",
"libc",
"log",
@@ -4610,8 +4612,8 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"arrow-array",
@@ -4641,8 +4643,8 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"arrow-array",
@@ -4659,8 +4661,8 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4697,8 +4699,8 @@ dependencies = [
[[package]]
name = "lance-file"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4730,8 +4732,8 @@ dependencies = [
[[package]]
name = "lance-geo"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"datafusion",
"geo-types",
@@ -4742,8 +4744,8 @@ dependencies = [
[[package]]
name = "lance-index"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"arrow-arith",
@@ -4789,6 +4791,7 @@ dependencies = [
"prost-types",
"rand 0.9.2",
"rand_distr 0.5.1",
"rangemap",
"rayon",
"roaring",
"serde",
@@ -4804,8 +4807,8 @@ dependencies = [
[[package]]
name = "lance-io"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"arrow-arith",
@@ -4845,8 +4848,8 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4862,8 +4865,8 @@ dependencies = [
[[package]]
name = "lance-namespace"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"async-trait",
@@ -4875,8 +4878,8 @@ dependencies = [
[[package]]
name = "lance-namespace-impls"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"arrow-ipc",
@@ -4884,7 +4887,6 @@ dependencies = [
"async-trait",
"axum",
"bytes",
"chrono",
"futures",
"lance",
"lance-core",
@@ -4906,9 +4908,9 @@ dependencies = [
[[package]]
name = "lance-namespace-reqwest-client"
version = "0.4.5"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2acdba67f84190067532fce07b51a435dd390d7cdc1129a05003e5cb3274cf0"
checksum = "00a21b43fe2a373896727b97927adedd2683d2907683f294f62cf8815fbf6a01"
dependencies = [
"reqwest",
"serde",
@@ -4919,8 +4921,8 @@ dependencies = [
[[package]]
name = "lance-table"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow",
"arrow-array",
@@ -4959,8 +4961,8 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "1.0.4"
source = "git+https://github.com/lance-format/lance.git?tag=v1.0.4#a93eaad1f6909a843cf8aa00d5530359012a7aaa"
version = "2.0.0-beta.3"
source = "git+https://github.com/lance-format/lance.git?tag=v2.0.0-beta.3#e6233665e377926ed2a8ceca667bc7a2f23341ae"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -4971,7 +4973,7 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.24.1"
version = "0.23.1-beta.1"
dependencies = [
"ahash",
"anyhow",
@@ -5050,7 +5052,7 @@ dependencies = [
[[package]]
name = "lancedb-nodejs"
version = "0.24.1"
version = "0.23.1-beta.1"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -5070,7 +5072,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.27.1"
version = "0.26.1-beta.1"
dependencies = [
"arrow",
"async-trait",
@@ -6726,8 +6728,8 @@ version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
dependencies = [
"heck 0.5.0",
"itertools 0.14.0",
"heck 0.4.1",
"itertools 0.12.1",
"log",
"multimap",
"once_cell",
@@ -6747,7 +6749,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
dependencies = [
"anyhow",
"itertools 0.14.0",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.106",
@@ -8077,7 +8079,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
dependencies = [
"heck 0.5.0",
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.106",

View File

@@ -12,23 +12,23 @@ repository = "https://github.com/lancedb/lancedb"
description = "Serverless, low-latency vector database for AI applications"
keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
rust-version = "1.88.0"
rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=1.0.4", default-features = false, "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=1.0.4", "tag" = "v1.0.4", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=2.0.0-beta.3", default-features = false, "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=2.0.0-beta.3", default-features = false, "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=2.0.0-beta.3", default-features = false, "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=2.0.0-beta.3", "tag" = "v2.0.0-beta.3", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "56.2", optional = false }

View File

@@ -66,7 +66,7 @@ Follow the [Quickstart](https://lancedb.com/docs/quickstart/) doc to set up Lanc
| Python SDK | https://lancedb.github.io/lancedb/python/python/ |
| Typescript SDK | https://lancedb.github.io/lancedb/js/globals/ |
| Rust SDK | https://docs.rs/lancedb/latest/lancedb/index.html |
| REST API | https://docs.lancedb.com/api-reference/rest |
| REST API | https://docs.lancedb.com/api-reference/introduction |
## **Join Us and Contribute**

View File

@@ -16,7 +16,7 @@ check_command_exists() {
}
if [[ ! -e ./lancedb ]]; then
if [[ x${SOPHON_READ_TOKEN} != "x" ]]; then
if [[ -v SOPHON_READ_TOKEN ]]; then
INPUT="lancedb-linux-x64"
gh release \
--repo lancedb/lancedb \

View File

@@ -11,7 +11,7 @@ watch:
theme:
name: "material"
logo: assets/logo.png
favicon: assets/favicon.ico
favicon: assets/logo.png
palette:
# Palette toggle for light mode
- scheme: lancedb
@@ -32,6 +32,8 @@ theme:
- content.tooltips
- toc.follow
- navigation.top
- navigation.tabs
- navigation.tabs.sticky
- navigation.footer
- navigation.tracking
- navigation.instant
@@ -113,13 +115,12 @@ markdown_extensions:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
- markdown.extensions.toc:
toc_depth: 3
permalink: true
permalink_title: Anchor link to this section
baselevel: 1
permalink: ""
nav:
- Documentation:
- SDK Reference: index.md
- API reference:
- Overview: index.md
- Python: python/python.md
- Javascript/TypeScript: js/globals.md
- Java: java/java.md

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

View File

@@ -1,111 +0,0 @@
# VoyageAI Embeddings : Multimodal
VoyageAI embeddings can also be used to embed both text and image data, only some of the models support image data and you can check the list
under [https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings)
Supported multimodal models:
- `voyage-multimodal-3` - 1024 dimensions (text + images)
- `voyage-multimodal-3.5` - Flexible dimensions (256, 512, 1024 default, 2048). Supports text, images, and video.
### Video Support (voyage-multimodal-3.5)
The `voyage-multimodal-3.5` model supports video input through:
- Video URLs (`.mp4`, `.webm`, `.mov`, `.avi`, `.mkv`, `.m4v`, `.gif`)
- Video file paths
Constraints: Max 20MB video size.
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|-------------------------|-------------------------------------------|
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |
| `output_dimension` | `int` | `None` | Output dimension for voyage-multimodal-3.5. Valid: 256, 512, 1024, 2048 |
Usage Example:
```python
import base64
import os
from io import BytesIO
import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd
os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'
db = lancedb.connect(".lancedb")
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")
def image_to_base64(image_bytes: bytes):
buffered = BytesIO(image_bytes)
img_str = base64.b64encode(buffered.getvalue())
return img_str.decode("utf-8")
class Images(LanceModel):
label: str
image_uri: str = func.SourceField() # image uri as the source
image_bytes: str = func.SourceField() # image bytes base64 encoded as the source
vector: Vector(func.ndims()) = func.VectorField() # vector column
vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column
if "images" in db.table_names():
db.drop_table("images")
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
images_bytes = [image_to_base64(requests.get(uri).content) for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
```
Now we can search using text from both the default vector column and the custom vector column
```python
# text search
actual = table.search("man's best friend", "vec_from_bytes").limit(1).to_pydantic(Images)[0]
print(actual.label) # prints "dog"
frombytes = (
table.search("man's best friend", vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
print(frombytes.label)
```
Because we're using a multi-modal embedding function, we can also search using images
```python
# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))
actual = table.search(query_image, "vec_from_bytes").limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")
# image search using a custom vector column
other = (
table.search(query_image, vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
print(actual.label)
```

View File

@@ -1,62 +0,0 @@
# VoyageAI Embeddings
Voyage AI provides cutting-edge embedding and rerankers.
Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.
Supported models are:
**Voyage-4 Series (Latest)**
- voyage-4 (1024 dims, general-purpose and multilingual retrieval, 320K batch tokens)
- voyage-4-lite (1024 dims, optimized for latency and cost, 1M batch tokens)
- voyage-4-large (1024 dims, best retrieval quality, 120K batch tokens)
**Voyage-3 Series**
- voyage-3
- voyage-3-lite
**Domain-Specific Models**
- voyage-finance-2
- voyage-multilingual-2
- voyage-law-2
- voyage-code-2
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|--------|---------|
| `name` | `str` | `None` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-4, voyage-4-lite, voyage-4-large, voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. |
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
Usage Example:
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
voyageai = EmbeddingFunctionRegistry
.get_instance()
.get("voyageai")
.create(name="voyage-3")
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```

View File

@@ -1,12 +1,8 @@
# SDK Reference
# API Reference
This site contains the API reference for the client SDKs supported by [LanceDB](https://lancedb.com).
This page contains the API reference for the SDKs supported by the LanceDB team.
- [Python](python/python.md)
- [JavaScript/TypeScript](js/globals.md)
- [Java](java/java.md)
- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)
!!! info "LanceDB Documentation"
If you're looking for the full documentation of LanceDB, visit [docs.lancedb.com](https://docs.lancedb.com).
- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.24.1</version>
<version>0.23.1-beta.1</version>
</dependency>
```

View File

@@ -85,26 +85,17 @@
/* Header gradient (only header area) */
.md-header {
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
box-shadow: inset 0 1px 0 rgba(255,255,255,0.08), 0 1px 0 rgba(0,0,0,0.08);
}
/* Improve brand title contrast on the lavender side */
.md-header__title,
.md-header__topic,
.md-header__title .md-ellipsis,
.md-header__topic .md-ellipsis {
color: #2b1b3a;
text-shadow: 0 1px 0 rgba(255, 255, 255, 0.25);
}
/* Same colors as header for tabs (that hold the text) */
.md-tabs {
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
}
/* Dark scheme variant */
[data-md-color-scheme="slate"] .md-header,
[data-md-color-scheme="slate"] .md-tabs {
background: linear-gradient(90deg, #e4d8f8 0%, #F0B7C1 45%, #E55A2B 100%);
background: linear-gradient(90deg, #3B2E58 0%, #F0B7C1 45%, #E55A2B 100%);
}

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.24.1-final.0</version>
<version>0.23.1-beta.1</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.24.1-final.0</version>
<version>0.23.1-beta.1</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>1.0.4</lance-core.version>
<lance-core.version>1.0.0-rc.2</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.24.1"
version = "0.23.1-beta.1"
license.workspace = true
description.workspace = true
repository.workspace = true
@@ -36,6 +36,6 @@ aws-lc-rs = "=1.13.0"
napi-build = "2.1"
[features]
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]
default = ["remote", "lancedb/default"]
fp16kernels = ["lancedb/fp16kernels"]
remote = ["lancedb/remote"]

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.24.1",
"version": "0.23.1-beta.1",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.24.1",
"version": "0.23.1-beta.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.28.0-beta.0"
current_version = "0.26.1-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -16,7 +16,7 @@ The Python package is a wrapper around the Rust library, `lancedb`. We use
To set up your development environment, you will need to install the following:
1. Python 3.10 or later
1. Python 3.9 or later
2. Cargo (Rust's package manager). Use [rustup](https://rustup.rs/) to install.
3. [protoc](https://grpc.io/docs/protoc-installation/) (Protocol Buffers compiler)

View File

@@ -1,13 +1,13 @@
[package]
name = "lancedb-python"
version = "0.28.0-beta.0"
version = "0.26.1-beta.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
rust-version = "1.88.0"
rust-version = "1.75.0"
[lib]
name = "_lancedb"
@@ -21,7 +21,7 @@ lance-core.workspace = true
lance-namespace.workspace = true
lance-io.workspace = true
env_logger.workspace = true
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py310"] }
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.25", features = [
"attributes",
"tokio-runtime",
@@ -34,10 +34,10 @@ tokio = { version = "1.40", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.25", features = [
"extension-module",
"abi3-py310",
"abi3-py39",
] }
[features]
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]
default = ["remote", "lancedb/default"]
fp16kernels = ["lancedb/fp16kernels"]
remote = ["lancedb/remote"]

View File

@@ -16,7 +16,7 @@ description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
license = { file = "LICENSE" }
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.9"
keywords = [
"data-format",
"data-science",
@@ -33,10 +33,10 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering",
]
@@ -137,4 +137,4 @@ include = [
"python/lancedb/_lancedb.pyi",
]
exclude = ["python/tests/"]
pythonVersion = "3.13"
pythonVersion = "3.12"

View File

@@ -13,7 +13,6 @@ __version__ = importlib.metadata.version("lancedb")
from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from urllib.parse import urlparse
from .db import AsyncConnection, DBConnection, LanceDBConnection
from .io import StorageOptionsProvider
from .remote import ClientConfig
@@ -29,39 +28,6 @@ from .namespace import (
)
def _check_s3_bucket_with_dots(
uri: str, storage_options: Optional[Dict[str, str]]
) -> None:
"""
Check if an S3 URI has a bucket name containing dots and warn if no region
is specified. S3 buckets with dots cannot use virtual-hosted-style URLs,
which breaks automatic region detection.
See: https://github.com/lancedb/lancedb/issues/1898
"""
if not isinstance(uri, str) or not uri.startswith("s3://"):
return
parsed = urlparse(uri)
bucket = parsed.netloc
if "." not in bucket:
return
# Check if region is provided in storage_options
region_keys = {"region", "aws_region"}
has_region = storage_options and any(k in storage_options for k in region_keys)
if not has_region:
raise ValueError(
f"S3 bucket name '{bucket}' contains dots, which prevents automatic "
f"region detection. Please specify the region explicitly via "
f"storage_options={{'region': '<your-region>'}} or "
f"storage_options={{'aws_region': '<your-region>'}}. "
f"See https://github.com/lancedb/lancedb/issues/1898 for details."
)
def connect(
uri: URI,
*,
@@ -155,11 +121,9 @@ def connect(
storage_options=storage_options,
**kwargs,
)
_check_s3_bucket_with_dots(str(uri), storage_options)
if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}")
return LanceDBConnection(
uri,
read_consistency_interval=read_consistency_interval,
@@ -247,8 +211,6 @@ async def connect_async(
if isinstance(client_config, dict):
client_config = ClientConfig(**client_config)
_check_s3_bucket_with_dots(str(uri), storage_options)
return AsyncConnection(
await lancedb_connect(
sanitize_uri(uri),

View File

@@ -179,7 +179,6 @@ class Table:
cleanup_since_ms: Optional[int] = None,
delete_unverified: Optional[bool] = None,
) -> OptimizeStats: ...
async def uri(self) -> str: ...
@property
def tags(self) -> Tags: ...
def query(self) -> Query: ...

View File

@@ -22,12 +22,7 @@ class BackgroundEventLoop:
self.thread.start()
def run(self, future):
concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop)
try:
return concurrent_future.result()
except BaseException:
concurrent_future.cancel()
raise
return asyncio.run_coroutine_threadsafe(future, self.loop).result()
LOOP = BackgroundEventLoop()

View File

@@ -210,8 +210,10 @@ class DBConnection(EnforceOverrides):
page_token: str, optional
The token to use for pagination. If not present, start from the beginning.
Typically, this token is last table name from the previous page.
Only supported by LanceDb Cloud.
limit: int, default 10
The size of the page to return.
Only supported by LanceDb Cloud.
Returns
-------

View File

@@ -275,7 +275,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
"""
Convert image inputs to PIL Images.
"""
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
@@ -285,12 +285,12 @@ class ColPaliEmbeddings(EmbeddingFunction):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL_Image.open(io.BytesIO(response.content)))
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
else:
with PIL_Image.open(image) as im:
with PIL.Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL_Image.open(io.BytesIO(image)))
pil_images.append(PIL.Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)

View File

@@ -77,8 +77,8 @@ class JinaEmbeddings(EmbeddingFunction):
if isinstance(inputs, list):
inputs = inputs
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, PIL.Image.Image):
inputs = [inputs]
return inputs
@@ -89,13 +89,13 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(image, (str, Path)):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if parsed.scheme == "file":
pil_image = PIL_Image.open(parsed.path)
pil_image = PIL.Image.open(parsed.path)
elif parsed.scheme == "":
pil_image = PIL_Image.open(image if os.name == "nt" else parsed.path)
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
pil_image = PIL_Image.open(io.BytesIO(url_retrieve(image)))
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
buffered = io.BytesIO()
@@ -103,9 +103,9 @@ class JinaEmbeddings(EmbeddingFunction):
image_bytes = buffered.getvalue()
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL_Image.Image):
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
@@ -136,9 +136,9 @@ class JinaEmbeddings(EmbeddingFunction):
elif isinstance(query, (Path, bytes)):
return [self.generate_image_embedding(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL_Image.Image):
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(

View File

@@ -71,8 +71,8 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
@@ -145,20 +145,20 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, bytes):
return PIL_Image.open(io.BytesIO(image))
if isinstance(image, PIL_Image.Image):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL_Image.open(parsed.path)
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL_Image.open(image if os.name == "nt" else parsed.path)
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL_Image.open(io.BytesIO(url_retrieve(image)))
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")

View File

@@ -56,8 +56,8 @@ class SigLipEmbeddings(EmbeddingFunction):
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(query, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("SigLIP supports str or PIL Image as query")
@@ -127,21 +127,21 @@ class SigLipEmbeddings(EmbeddingFunction):
return image_features.cpu().detach().numpy().squeeze()
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(image, PIL_Image.Image):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL.Image.Image):
return image.convert("RGB") if image.mode != "RGB" else image
elif isinstance(image, bytes):
return PIL_Image.open(io.BytesIO(image)).convert("RGB")
return PIL.Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
if parsed.scheme == "file":
return PIL_Image.open(parsed.path).convert("RGB")
return PIL.Image.open(parsed.path).convert("RGB")
elif parsed.scheme == "":
path = image if os.name == "nt" else parsed.path
return PIL_Image.open(path).convert("RGB")
return PIL.Image.open(path).convert("RGB")
elif parsed.scheme.startswith("http"):
image_bytes = url_retrieve(image)
return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
else:
raise NotImplementedError("Only local and http(s) urls are supported")
else:

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import base64
import os
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator, Optional
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
from pathlib import Path
from urllib.parse import urlparse
@@ -21,9 +21,6 @@ if TYPE_CHECKING:
# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
"voyage-4": 320_000,
"voyage-4-lite": 1_000_000,
"voyage-4-large": 120_000,
"voyage-context-3": 32_000,
"voyage-3.5-lite": 1_000_000,
"voyage-3.5": 320_000,
@@ -48,32 +45,14 @@ def is_valid_url(text):
return False
VIDEO_EXTENSIONS = {".mp4", ".webm", ".mov", ".avi", ".mkv", ".m4v", ".gif"}
def is_video_url(url: str) -> bool:
"""Check if URL points to a video file based on extension."""
parsed = urlparse(url)
path = parsed.path.lower()
return any(path.endswith(ext) for ext in VIDEO_EXTENSIONS)
def is_video_path(path: Path) -> bool:
"""Check if file path is a video file based on extension."""
return path.suffix.lower() in VIDEO_EXTENSIONS
def transform_input(input_data: Union[str, bytes, Path]):
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(input_data, str):
if is_valid_url(input_data):
if is_video_url(input_data):
content = {"type": "video_url", "video_url": input_data}
else:
content = {"type": "image_url", "image_url": input_data}
content = {"type": "image_url", "image_url": input_data}
else:
content = {"type": "text", "text": input_data}
elif isinstance(input_data, PIL_Image.Image):
elif isinstance(input_data, PIL.Image.Image):
buffered = BytesIO()
input_data.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -82,7 +61,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, bytes):
img = PIL_Image.open(BytesIO(input_data))
img = PIL.Image.open(BytesIO(input_data))
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -91,24 +70,14 @@ def transform_input(input_data: Union[str, bytes, Path]):
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, Path):
if is_video_path(input_data):
# Read video file and encode as base64
with open(input_data, "rb") as f:
video_bytes = f.read()
video_str = base64.b64encode(video_bytes).decode("utf-8")
content = {
"type": "video_base64",
"video_base64": video_str,
}
else:
img = PIL_Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
else:
raise ValueError("Each input should be either str, bytes, Path or Image.")
@@ -119,11 +88,9 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
"""
Sanitize the input to the embedding function.
"""
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL_Image.Image)):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
inputs = [inputs]
elif isinstance(inputs, list):
pass # Already a list, use as-is
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
@@ -133,7 +100,7 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
f"Input type {type(inputs)} not allowed with multimodal model."
)
if not all(isinstance(x, (str, bytes, Path, PIL_Image.Image)) for x in inputs):
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
raise ValueError("Each input should be either str, bytes, Path or Image.")
return [transform_input(i) for i in inputs]
@@ -170,25 +137,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
name: str
The name of the model to use. List of acceptable models:
* voyage-4 (1024 dims, general-purpose and multilingual retrieval)
* voyage-4-lite (1024 dims, optimized for latency and cost)
* voyage-4-large (1024 dims, best retrieval quality)
* voyage-context-3
* voyage-3.5
* voyage-3.5-lite
* voyage-3
* voyage-3-lite
* voyage-multimodal-3
* voyage-multimodal-3.5
* voyage-finance-2
* voyage-multilingual-2
* voyage-law-2
* voyage-code-2
output_dimension: int, optional
The output dimension for models that support flexible dimensions.
Currently only voyage-multimodal-3.5 supports this feature.
Valid options: 256, 512, 1024 (default), 2048.
Examples
--------
@@ -216,14 +175,8 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"""
name: str
output_dimension: Optional[int] = None
client: ClassVar = None
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
text_embedding_models: list = [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-3.5",
"voyage-3.5-lite",
"voyage-3",
@@ -233,7 +186,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-law-2",
"voyage-code-2",
]
multimodal_embedding_models: list = ["voyage-multimodal-3", "voyage-multimodal-3.5"]
multimodal_embedding_models: list = ["voyage-multimodal-3"]
contextual_embedding_models: list = ["voyage-context-3"]
def _is_multimodal_model(self, model_name: str):
@@ -245,25 +198,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
return model_name in self.contextual_embedding_models or "context" in model_name
def ndims(self):
# Handle flexible dimension models
if self.name in self._FLEXIBLE_DIM_MODELS:
if self.output_dimension is not None:
if self.output_dimension not in self._VALID_DIMENSIONS:
raise ValueError(
f"Invalid output_dimension {self.output_dimension} "
f"for {self.name}. Valid options: {self._VALID_DIMENSIONS}"
)
return self.output_dimension
return 1024 # default dimension
if self.name == "voyage-3-lite":
return 512
elif self.name == "voyage-code-2":
return 1536
elif self.name in [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-context-3",
"voyage-3.5",
"voyage-3.5-lite",
@@ -272,17 +211,12 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-finance-2",
"voyage-multilingual-2",
"voyage-law-2",
"voyage-multimodal-3",
]:
return 1024
else:
raise ValueError(f"Model {self.name} not supported")
def _get_multimodal_kwargs(self, **kwargs):
"""Get kwargs for multimodal embed call, including output_dimension if set."""
if self.name in self._FLEXIBLE_DIM_MODELS and self.output_dimension is not None:
kwargs["output_dimension"] = self.output_dimension
return kwargs
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
@@ -300,7 +234,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"""
client = VoyageAIEmbeddingFunction._get_client()
if self._is_multimodal_model(self.name):
kwargs = self._get_multimodal_kwargs(**kwargs)
result = client.multimodal_embed(
inputs=[[query]], model=self.name, input_type="query", **kwargs
)
@@ -342,7 +275,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
)
if has_images:
# Use non-batched API for images
kwargs = self._get_multimodal_kwargs(**kwargs)
result = client.multimodal_embed(
inputs=sanitized, model=self.name, input_type="document", **kwargs
)
@@ -425,7 +357,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
callable: A function that takes a batch of texts and returns embeddings.
"""
if self._is_multimodal_model(self.name):
multimodal_kwargs = self._get_multimodal_kwargs(**kwargs)
def embed_batch(batch: List[str]) -> List[np.array]:
batch_inputs = sanitize_multimodal_input(batch)
@@ -433,7 +364,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
inputs=batch_inputs,
model=self.name,
input_type=input_type,
**multimodal_kwargs,
**kwargs,
)
return result.embeddings

View File

@@ -275,7 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
return pa.timestamp("us", tz=tz)
elif getattr(py_type, "__origin__", None) in (list, tuple):
child = py_type.__args__[0]
return _pydantic_list_child_to_arrow(child, field)
return pa.list_(_py_type_to_arrow_type(child, field))
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
)
@@ -298,18 +298,12 @@ else:
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
def _safe_issubclass(candidate: Any, base: type) -> bool:
try:
return issubclass(candidate, base)
except TypeError:
return False
if inspect.isclass(tp):
if _safe_issubclass(tp, pydantic.BaseModel):
if issubclass(tp, pydantic.BaseModel):
# Struct
fields = _pydantic_model_to_fields(tp)
return pa.struct(fields)
if _safe_issubclass(tp, FixedSizeListMixin):
if issubclass(tp, FixedSizeListMixin):
if getattr(tp, "is_multi_vector", lambda: False)():
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
# For regular Vector
@@ -317,67 +311,45 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
return _py_type_to_arrow_type(tp, field)
def _pydantic_list_child_to_arrow(child: Any, field: FieldInfo) -> pa.DataType:
unwrapped = _unwrap_optional_annotation(child)
if unwrapped is not None:
return pa.list_(
pa.field("item", _pydantic_type_to_arrow_type(unwrapped, field), True)
)
return pa.list_(_pydantic_type_to_arrow_type(child, field))
def _unwrap_optional_annotation(annotation: Any) -> Any | None:
if isinstance(annotation, (_GenericAlias, GenericAlias)):
origin = annotation.__origin__
args = annotation.__args__
if origin == Union:
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
return non_none[0]
elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
args = annotation.__args__
non_none = [arg for arg in args if arg is not type(None)]
if len(non_none) == 1 and len(non_none) != len(args):
return non_none[0]
return None
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType"""
unwrapped = _unwrap_optional_annotation(field.annotation)
if unwrapped is not None:
return _pydantic_type_to_arrow_type(unwrapped, field)
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin is list:
child = args[0]
return _pydantic_list_child_to_arrow(child, field)
return pa.list_(_py_type_to_arrow_type(child, field))
elif origin == Union:
if len(args) == 2 and args[1] is type(None):
return _pydantic_type_to_arrow_type(args[0], field)
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
if len(args) == 2:
for typ in args:
if typ is type(None):
continue
return _py_type_to_arrow_type(typ, field)
return _pydantic_type_to_arrow_type(field.annotation, field)
def is_nullable(field: FieldInfo) -> bool:
"""Check if a Pydantic FieldInfo is nullable."""
if _unwrap_optional_annotation(field.annotation) is not None:
return True
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin == Union:
if any(typ is type(None) for typ in args):
if len(args) == 2 and args[1] is type(None):
return True
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
for typ in args:
if typ is type(None):
return True
elif inspect.isclass(field.annotation):
try:
if issubclass(field.annotation, FixedSizeListMixin):
return field.annotation.nullable()
except TypeError:
return False
elif inspect.isclass(field.annotation) and issubclass(
field.annotation, FixedSizeListMixin
):
return field.annotation.nullable()
return False

View File

@@ -384,7 +384,6 @@ class RemoteDBConnection(DBConnection):
on_bad_vectors: str = "error",
fill_value: float = 0.0,
mode: Optional[str] = None,
exist_ok: bool = False,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*,
namespace: Optional[List[str]] = None,
@@ -413,12 +412,6 @@ class RemoteDBConnection(DBConnection):
- pyarrow.Schema
- [LanceModel][lancedb.pydantic.LanceModel]
mode: str, default "create"
The mode to use when creating the table.
Can be either "create", "overwrite", or "exist_ok".
exist_ok: bool, default False
If exist_ok is True, and mode is None or "create", mode will be changed
to "exist_ok".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
@@ -490,11 +483,6 @@ class RemoteDBConnection(DBConnection):
LanceTable(table4)
"""
if exist_ok:
if mode == "create":
mode = "exist_ok"
elif not mode:
mode = "exist_ok"
if namespace is None:
namespace = []
validate_table_name(name)

View File

@@ -18,17 +18,7 @@ from lancedb._lancedb import (
UpdateResult,
)
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import (
FTS,
BTree,
Bitmap,
HnswSq,
IvfFlat,
IvfPq,
IvfRq,
IvfSq,
LabelList,
)
from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, IvfSq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa
@@ -275,12 +265,6 @@ class RemoteTable(Table):
num_sub_vectors=num_sub_vectors,
num_bits=num_bits,
)
elif index_type == "IVF_RQ":
config = IvfRq(
distance_type=metric,
num_partitions=num_partitions,
num_bits=num_bits,
)
elif index_type == "IVF_SQ":
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
elif index_type == "IVF_HNSW_PQ":
@@ -295,8 +279,7 @@ class RemoteTable(Table):
else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
" 'IVF_FLAT', 'IVF_SQ', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
LOOP.run(
@@ -655,14 +638,6 @@ class RemoteTable(Table):
def stats(self):
return LOOP.run(self._table.stats())
@property
def uri(self) -> str:
"""The table URI (storage location).
For remote tables, this fetches the location from the server via describe.
"""
return LOOP.run(self._table.uri())
def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
return LanceTakeQueryBuilder(self._table.take_offsets(offsets))

View File

@@ -2218,10 +2218,6 @@ class LanceTable(Table):
def stats(self) -> TableStatistics:
return LOOP.run(self._table.stats())
@property
def uri(self) -> str:
return LOOP.run(self._table.uri())
def create_scalar_index(
self,
column: str,
@@ -3610,20 +3606,6 @@ class AsyncTable:
"""
return await self._inner.stats()
async def uri(self) -> str:
"""
Get the table URI (storage location).
For remote tables, this fetches the location from the server via describe.
For local tables, this returns the dataset URI.
Returns
-------
str
The full storage location of the table (e.g., S3/GCS path).
"""
return await self._inner.uri()
async def add(
self,
data: DATA,

View File

@@ -2,27 +2,12 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from datetime import timedelta
from lancedb.db import AsyncConnection, DBConnection
import lancedb
import pytest
import pytest_asyncio
def pandas_string_type():
"""Return the PyArrow string type that pandas uses for string columns.
pandas 3.0+ uses large_string for string columns, pandas 2.x uses string.
"""
import pandas as pd
import pyarrow as pa
version = tuple(int(x) for x in pd.__version__.split(".")[:2])
if version >= (3, 0):
return pa.large_utf8()
return pa.utf8()
# Use an in-memory database for most tests.
@pytest.fixture
def mem_db() -> DBConnection:

View File

@@ -268,8 +268,6 @@ async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConne
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
from conftest import pandas_string_type
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -288,11 +286,10 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
assert tbl.schema == tbl2.schema
assert len(tbl) == len(tbl2)
# pandas 3.0+ uses large_string, pandas 2.x uses string
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field("item", pandas_string_type()),
pa.field("item", pa.utf8()),
pa.field("price", pa.float64()),
]
)
@@ -302,7 +299,7 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
bad_schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field("item", pandas_string_type()),
pa.field("item", pa.utf8()),
pa.field("price", pa.float64()),
pa.field("extra", pa.float32()),
]
@@ -368,8 +365,6 @@ async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
@pytest.mark.asyncio
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
from conftest import pandas_string_type
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -387,11 +382,10 @@ async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
assert tbl.name == tbl2.name
assert await tbl.schema() == await tbl2.schema()
# pandas 3.0+ uses large_string, pandas 2.x uses string
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field("item", pandas_string_type()),
pa.field("item", pa.utf8()),
pa.field("price", pa.float64()),
]
)
@@ -601,8 +595,6 @@ def test_open_table_sync(tmp_db: lancedb.DBConnection):
@pytest.mark.asyncio
async def test_open_table(tmp_path):
from conftest import pandas_string_type
db = await lancedb.connect_async(tmp_path)
data = pd.DataFrame(
{
@@ -622,11 +614,10 @@ async def test_open_table(tmp_path):
)
is not None
)
# pandas 3.0+ uses large_string, pandas 2.x uses string
assert await tbl.schema() == pa.schema(
{
"vector": pa.list_(pa.float32(), list_size=2),
"item": pandas_string_type(),
"item": pa.utf8(),
"price": pa.float64(),
}
)

View File

@@ -517,36 +517,19 @@ def test_ollama_embedding(tmp_path):
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize(
"model_name,expected_dims",
[
("voyage-3", 1024),
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
],
)
def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
"""Integration test for VoyageAI text embedding models with real API calls."""
voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)
def test_voyageai_embedding_function():
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == expected_dims, (
f"{model_name} should have {expected_dims} dimensions"
)
# Test search functionality
result = tbl.search("hello").limit(1).to_pandas()
assert result["text"][0] == "hello world"
@pytest.mark.slow
@@ -630,133 +613,6 @@ def test_voyageai_multimodal_embedding_text_function():
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_embedding_function():
"""Test voyage-multimodal-3.5 model with text input."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", max_retries=0)
)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test_multimodal_35", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == 1024
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_flexible_dimensions():
"""Test voyage-multimodal-3.5 model with custom output dimension."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", output_dimension=512, max_retries=0)
)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
assert voyageai.ndims() == 512
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test_multimodal_35_dim", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == 512
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_image_embedding():
"""Test voyage-multimodal-3.5 model with image input."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", max_retries=0)
)
class Images(LanceModel):
label: str
image_uri: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
db = lancedb.connect("~/lancedb")
table = db.create_table(
"test_multimodal_35_images", schema=Images, mode="overwrite"
)
labels = ["cat", "dog"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
]
table.add(pd.DataFrame({"label": labels, "image_uri": uris}))
assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == 1024
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize("dimension", [256, 512, 1024, 2048])
def test_voyageai_multimodal_35_all_dimensions(dimension):
"""Test voyage-multimodal-3.5 model with all valid output dimensions."""
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", output_dimension=dimension, max_retries=0)
)
assert voyageai.ndims() == dimension
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table(
f"test_multimodal_35_dim_{dimension}", schema=TextModel, mode="overwrite"
)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == dimension
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_35_invalid_dimension():
"""Test voyage-multimodal-3.5 model raises error for invalid output dimension."""
with pytest.raises(ValueError, match="Invalid output_dimension"):
voyageai = (
get_registry()
.get("voyageai")
.create(name="voyage-multimodal-3.5", output_dimension=999, max_retries=0)
)
# ndims() is where the validation happens
voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,

View File

@@ -26,8 +26,6 @@ import pytest
from lance_namespace import (
CreateEmptyTableRequest,
CreateEmptyTableResponse,
DeclareTableRequest,
DeclareTableResponse,
DescribeTableRequest,
DescribeTableResponse,
LanceNamespace,
@@ -162,19 +160,6 @@ class TrackingNamespace(LanceNamespace):
return modified
def declare_table(self, request: DeclareTableRequest) -> DeclareTableResponse:
"""Track declare_table calls and inject rotating credentials."""
with self.lock:
self.create_call_count += 1
count = self.create_call_count
response = self.inner.declare_table(request)
response.storage_options = self._modify_storage_options(
response.storage_options, count
)
return response
def create_empty_table(
self, request: CreateEmptyTableRequest
) -> CreateEmptyTableResponse:

View File

@@ -438,15 +438,11 @@ def test_filter_with_splits(mem_db):
row_count = permutation_tbl.count_rows()
assert row_count == 67
# Verify the permutation table only contains row_id and split_id
assert set(permutation_tbl.schema.names) == {"row_id", "split_id"}
row_ids = permutation_tbl.search(None).to_arrow().to_pydict()["row_id"]
data = tbl.take_row_ids(row_ids).to_arrow().to_pydict()
data = permutation_tbl.search(None).to_arrow().to_pydict()
categories = data["category"]
# All categories should be A or B
assert all(cat in ("A", "B") for cat in categories)
assert all(cat in ["A", "B"] for cat in categories)
def test_filter_with_shuffle(mem_db):

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import json
import sys
from datetime import date, datetime
from typing import List, Optional, Tuple
@@ -19,6 +20,10 @@ from pydantic import BaseModel
from pydantic import Field
@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow():
class StructModel(pydantic.BaseModel):
a: str
@@ -78,6 +83,10 @@ def test_pydantic_to_arrow():
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="using | type syntax requires python3.10 or higher",
)
def test_optional_types_py310():
class TestModel(pydantic.BaseModel):
a: str | None
@@ -96,233 +105,10 @@ def test_optional_types_py310():
assert schema == expect_schema
def test_optional_structs():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
split: SplitInfo | None = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"split",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
),
]
)
assert schema == expect_schema
def test_optional_struct_list_py310():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[SplitInfo] | None = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
)
),
True,
),
]
)
assert schema == expect_schema
def test_nested_struct_list():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[SplitInfo]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
)
),
False,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: Optional[list[SplitInfo]] = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
)
),
True,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional_items():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[Optional[SplitInfo]]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.field(
"item",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
)
),
False,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional_container_and_items():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: Optional[list[Optional[SplitInfo]]] = None
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.field(
"item",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
)
),
True,
),
]
)
assert schema == expect_schema
def test_nested_struct_list_optional_items_pep604():
class SplitInfo(pydantic.BaseModel):
start_frame: int
end_frame: int
class TestModel(pydantic.BaseModel):
id: str
splits: list[SplitInfo | None]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.utf8(), False),
pa.field(
"splits",
pa.list_(
pa.field(
"item",
pa.struct(
[
pa.field("start_frame", pa.int64(), False),
pa.field("end_frame", pa.int64(), False),
]
),
True,
)
),
False,
),
]
)
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info > (3, 8),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow_py38():
class StructModel(pydantic.BaseModel):
a: str

View File

@@ -8,7 +8,7 @@ import http.server
import json
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import uuid
from packaging.version import Version
@@ -168,42 +168,6 @@ def test_table_len_sync():
assert len(table) == 1
def test_create_table_exist_ok():
def handler(request):
if request.path == "/v1/table/test/create/?mode=exist_ok":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}], exist_ok=True)
assert table is not None
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}], mode="create", exist_ok=True)
assert table is not None
def test_create_table_exist_ok_with_mode_overwrite():
def handler(request):
if request.path == "/v1/table/test/create/?mode=overwrite":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}], mode="overwrite", exist_ok=True)
assert table is not None
@pytest.mark.asyncio
async def test_http_error():
request_id_holder = {"request_id": None}
@@ -1203,22 +1167,3 @@ async def test_header_provider_overrides_static_headers():
extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
) as db:
await db.table_names()
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
def test_background_loop_cancellation(exception):
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
from lancedb.background_loop import BackgroundEventLoop
mock_future = MagicMock()
mock_future.result.side_effect = exception()
with (
patch.object(BackgroundEventLoop, "__init__", return_value=None),
patch("asyncio.run_coroutine_threadsafe", return_value=mock_future),
):
loop = BackgroundEventLoop()
loop.loop = MagicMock()
with pytest.raises(exception):
loop.run(None)
mock_future.cancel.assert_called_once()

View File

@@ -1,68 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""
Tests for S3 bucket names containing dots.
Related issue: https://github.com/lancedb/lancedb/issues/1898
These tests validate the early error checking for S3 bucket names with dots.
No actual S3 connection is made - validation happens before connection.
"""
import pytest
import lancedb
# Test URIs
BUCKET_WITH_DOTS = "s3://my.bucket.name/path"
BUCKET_WITH_DOTS_AND_REGION = ("s3://my.bucket.name", {"region": "us-east-1"})
BUCKET_WITH_DOTS_AND_AWS_REGION = ("s3://my.bucket.name", {"aws_region": "us-east-1"})
BUCKET_WITHOUT_DOTS = "s3://my-bucket/path"
class TestS3BucketWithDotsSync:
"""Tests for connect()."""
def test_bucket_with_dots_requires_region(self):
with pytest.raises(ValueError, match="contains dots"):
lancedb.connect(BUCKET_WITH_DOTS)
def test_bucket_with_dots_and_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_REGION
db = lancedb.connect(uri, storage_options=opts)
assert db is not None
def test_bucket_with_dots_and_aws_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
db = lancedb.connect(uri, storage_options=opts)
assert db is not None
def test_bucket_without_dots_passes(self):
db = lancedb.connect(BUCKET_WITHOUT_DOTS)
assert db is not None
class TestS3BucketWithDotsAsync:
"""Tests for connect_async()."""
@pytest.mark.asyncio
async def test_bucket_with_dots_requires_region(self):
with pytest.raises(ValueError, match="contains dots"):
await lancedb.connect_async(BUCKET_WITH_DOTS)
@pytest.mark.asyncio
async def test_bucket_with_dots_and_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_REGION
db = await lancedb.connect_async(uri, storage_options=opts)
assert db is not None
@pytest.mark.asyncio
async def test_bucket_with_dots_and_aws_region_passes(self):
uri, opts = BUCKET_WITH_DOTS_AND_AWS_REGION
db = await lancedb.connect_async(uri, storage_options=opts)
assert db is not None
@pytest.mark.asyncio
async def test_bucket_without_dots_passes(self):
db = await lancedb.connect_async(BUCKET_WITHOUT_DOTS)
assert db is not None

View File

@@ -1967,9 +1967,3 @@ def test_add_table_with_empty_embeddings(tmp_path):
on_bad_vectors="drop",
)
assert table.count_rows() == 1
def test_table_uri(tmp_path):
db = lancedb.connect(tmp_path)
table = db.create_table("my_table", data=[{"x": 0}])
assert table.uri == str(tmp_path / "my_table.lance")

View File

@@ -528,19 +528,12 @@ def test_sanitize_data(
else:
expected_schema = schema
else:
from conftest import pandas_string_type
# polars uses large_string, pandas 3.0+ uses large_string, others use string
if isinstance(data, pl.DataFrame):
text_type = pa.large_utf8()
elif isinstance(data, pd.DataFrame):
text_type = pandas_string_type()
else:
text_type = pa.string()
expected_schema = pa.schema(
{
"id": pa.int64(),
"text": text_type,
"text": pa.large_utf8()
if isinstance(data, pl.DataFrame)
else pa.string(),
"vector": pa.list_(pa.float32(), 10),
}
)

View File

@@ -1,108 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Unit tests for VoyageAI embedding function.
These tests verify model registration and configuration without requiring API calls.
"""
import pytest
from unittest.mock import MagicMock, patch
from lancedb.embeddings import get_registry
@pytest.fixture(autouse=True)
def reset_voyageai_client():
"""Reset VoyageAI client before and after each test to avoid state pollution."""
from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
VoyageAIEmbeddingFunction.client = None
yield
VoyageAIEmbeddingFunction.client = None
class TestVoyageAIModelRegistration:
"""Tests for VoyageAI model registration and configuration."""
@pytest.fixture
def mock_voyageai_client(self):
"""Mock VoyageAI client to avoid API calls."""
with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
mock_client = MagicMock()
mock_voyageai = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock.return_value = mock_voyageai
yield mock_client
def test_voyageai_registered(self):
"""Test that VoyageAI is registered in the embedding function registry."""
registry = get_registry()
assert registry.get("voyageai") is not None
@pytest.mark.parametrize(
"model_name,expected_dims",
[
# Voyage-4 series (all 1024 dims)
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
# Voyage-3 series
("voyage-3", 1024),
("voyage-3-lite", 512),
# Domain-specific models
("voyage-finance-2", 1024),
("voyage-multilingual-2", 1024),
("voyage-law-2", 1024),
("voyage-code-2", 1536),
# Multimodal
("voyage-multimodal-3", 1024),
],
)
def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
"""Test that each model returns the correct dimensions."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert func.ndims() == expected_dims, (
f"Model {model_name} should have {expected_dims} dimensions"
)
def test_unsupported_model_raises_error(self, mock_voyageai_client):
"""Test that unsupported models raise ValueError."""
registry = get_registry()
func = registry.get("voyageai").create(name="unsupported-model")
with pytest.raises(ValueError, match="not supported"):
func.ndims()
@pytest.mark.parametrize(
"model_name",
[
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
],
)
def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
"""Test that voyage-4 models are classified as text models (not multimodal)."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert not func._is_multimodal_model(model_name), (
f"{model_name} should be a text model, not multimodal"
)
def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
"""Test that voyage-4 models are in the text_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" in func.text_embedding_models
assert "voyage-4-lite" in func.text_embedding_models
assert "voyage-4-large" in func.text_embedding_models
def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
"""Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" not in func.multimodal_embedding_models
assert "voyage-4-lite" not in func.multimodal_embedding_models
assert "voyage-4-large" not in func.multimodal_embedding_models

View File

@@ -304,7 +304,6 @@ impl Connection {
},
page_token,
limit: limit.map(|l| l as i32),
..Default::default()
};
let response = inner.list_namespaces(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
@@ -327,12 +326,11 @@ impl Connection {
let py = self_.py();
future_into_py(py, async move {
use lance_namespace::models::CreateNamespaceRequest;
// Mode is now a string field
let mode_str = mode.and_then(|m| match m.to_lowercase().as_str() {
"create" => Some("Create".to_string()),
"exist_ok" => Some("ExistOk".to_string()),
"overwrite" => Some("Overwrite".to_string()),
_ => None,
let mode_value = mode.map(|m| match m.to_lowercase().as_str() {
"create" => "Create".to_string(),
"exist_ok" => "ExistOk".to_string(),
"overwrite" => "Overwrite".to_string(),
_ => m,
});
let request = CreateNamespaceRequest {
id: if namespace.is_empty() {
@@ -340,9 +338,8 @@ impl Connection {
} else {
Some(namespace)
},
mode: mode_str,
mode: mode_value,
properties,
..Default::default()
};
let response = inner.create_namespace(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
@@ -364,16 +361,15 @@ impl Connection {
let py = self_.py();
future_into_py(py, async move {
use lance_namespace::models::DropNamespaceRequest;
// Mode and Behavior are now string fields
let mode_str = mode.and_then(|m| match m.to_uppercase().as_str() {
"SKIP" => Some("Skip".to_string()),
"FAIL" => Some("Fail".to_string()),
_ => None,
let mode_value = mode.map(|m| match m.to_uppercase().as_str() {
"SKIP" => "Skip".to_string(),
"FAIL" => "Fail".to_string(),
_ => m,
});
let behavior_str = behavior.and_then(|b| match b.to_uppercase().as_str() {
"RESTRICT" => Some("Restrict".to_string()),
"CASCADE" => Some("Cascade".to_string()),
_ => None,
let behavior_value = behavior.map(|b| match b.to_uppercase().as_str() {
"RESTRICT" => "Restrict".to_string(),
"CASCADE" => "Cascade".to_string(),
_ => b,
});
let request = DropNamespaceRequest {
id: if namespace.is_empty() {
@@ -381,9 +377,8 @@ impl Connection {
} else {
Some(namespace)
},
mode: mode_str,
behavior: behavior_str,
..Default::default()
mode: mode_value,
behavior: behavior_value,
};
let response = inner.drop_namespace(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
@@ -410,7 +405,6 @@ impl Connection {
} else {
Some(namespace)
},
..Default::default()
};
let response = inner.describe_namespace(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {
@@ -440,7 +434,6 @@ impl Connection {
},
page_token,
limit: limit.map(|l| l as i32),
..Default::default()
};
let response = inner.list_tables(request).await.infer_error()?;
Python::with_gil(|py| -> PyResult<Py<PyDict>> {

View File

@@ -497,11 +497,6 @@ impl Table {
})
}
pub fn uri(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move { inner.uri().await.infer_error() })
}
pub fn __repr__(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.24.1"
version = "0.23.1-beta.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -104,16 +104,11 @@ test-log = "0.2"
[features]
default = []
default = ["aws", "gcs", "azure", "dynamodb", "oss"]
aws = ["lance/aws", "lance-io/aws", "lance-namespace-impls/dir-aws"]
oss = ["lance/oss", "lance-io/oss", "lance-namespace-impls/dir-oss"]
gcs = ["lance/gcp", "lance-io/gcp", "lance-namespace-impls/dir-gcp"]
azure = ["lance/azure", "lance-io/azure", "lance-namespace-impls/dir-azure"]
huggingface = [
"lance/huggingface",
"lance-io/huggingface",
"lance-namespace-impls/dir-huggingface",
]
dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
fp16kernels = ["lance-linalg/fp16kernels"]
@@ -153,6 +148,3 @@ name = "ivf_pq"
[[example]]
name = "hybrid_search"
required-features = ["sentence-transformers"]
[package.metadata.docs.rs]
all-features = true

View File

@@ -36,42 +36,10 @@ use crate::remote::{
};
use crate::table::{TableDefinition, WriteOptions};
use crate::Table;
use lance::io::ObjectStoreParams;
pub use lance_encoding::version::LanceFileVersion;
#[cfg(feature = "remote")]
use lance_io::object_store::StorageOptions;
use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider};
fn merge_storage_options(
store_params: &mut ObjectStoreParams,
pairs: impl IntoIterator<Item = (String, String)>,
) {
let mut options = store_params.storage_options().cloned().unwrap_or_default();
for (key, value) in pairs {
options.insert(key, value);
}
let provider = store_params
.storage_options_accessor
.as_ref()
.and_then(|accessor| accessor.provider().cloned());
let accessor = if let Some(provider) = provider {
StorageOptionsAccessor::with_initial_and_provider(options, provider)
} else {
StorageOptionsAccessor::with_static_options(options)
};
store_params.storage_options_accessor = Some(Arc::new(accessor));
}
fn set_storage_options_provider(
store_params: &mut ObjectStoreParams,
provider: Arc<dyn StorageOptionsProvider>,
) {
let accessor = match store_params.storage_options().cloned() {
Some(options) => StorageOptionsAccessor::with_initial_and_provider(options, provider),
None => StorageOptionsAccessor::with_provider(provider),
};
store_params.storage_options_accessor = Some(Arc::new(accessor));
}
use lance_io::object_store::StorageOptionsProvider;
/// A builder for configuring a [`Connection::table_names`] operation
pub struct TableNamesBuilder {
@@ -251,36 +219,8 @@ impl CreateTableBuilder<false> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
let parent = self.parent.clone();
let embedding_registry = self.embedding_registry.clone();
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
parent,
embedding_registry,
))
}
fn into_request(self) -> Result<CreateTableRequest> {
if self.embeddings.is_empty() {
return Ok(self.request);
}
let CreateTableData::Empty(table_def) = self.request.data else {
unreachable!("CreateTableBuilder<false> should always have Empty data")
};
let schema = table_def.schema.clone();
let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone());
let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::<Vec<_>>());
let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema);
let with_embeddings = WithEmbeddings::new(reader, self.embeddings);
let table_definition = with_embeddings.table_definition()?;
Ok(CreateTableRequest {
data: CreateTableData::Empty(table_definition),
..self.request
})
let table = parent.create_table(self.request).await?;
Ok(Table::new(table, parent))
}
}
@@ -306,14 +246,16 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
///
/// See available options at <https://lancedb.com/docs/storage/>
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
let store_params = self
let store_options = self
.request
.write_options
.lance_write_params
.get_or_insert(Default::default())
.store_params
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
merge_storage_options(store_params, [(key.into(), value.into())]);
store_options.insert(key.into(), value.into());
self
}
@@ -327,17 +269,19 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
mut self,
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> Self {
let store_params = self
let store_options = self
.request
.write_options
.lance_write_params
.get_or_insert(Default::default())
.store_params
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
let updates = pairs
.into_iter()
.map(|(key, value)| (key.into(), value.into()));
merge_storage_options(store_params, updates);
for (key, value) in pairs {
store_options.insert(key.into(), value.into());
}
self
}
@@ -374,21 +318,23 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// This has no effect in LanceDB Cloud.
#[deprecated(since = "0.15.1", note = "Use `database_options` instead")]
pub fn enable_v2_manifest_paths(mut self, use_v2_manifest_paths: bool) -> Self {
let store_params = self
let storage_options = self
.request
.write_options
.lance_write_params
.get_or_insert_with(Default::default)
.store_params
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default);
let value = if use_v2_manifest_paths {
"true".to_string()
} else {
"false".to_string()
};
merge_storage_options(
store_params,
[(OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(), value)],
storage_options.insert(
OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(),
if use_v2_manifest_paths {
"true".to_string()
} else {
"false".to_string()
},
);
self
}
@@ -398,19 +344,19 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// The default is `LanceFileVersion::Stable`.
#[deprecated(since = "0.15.1", note = "Use `database_options` instead")]
pub fn data_storage_version(mut self, data_storage_version: LanceFileVersion) -> Self {
let store_params = self
let storage_options = self
.request
.write_options
.lance_write_params
.get_or_insert_with(Default::default)
.store_params
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default);
merge_storage_options(
store_params,
[(
OPT_NEW_TABLE_STORAGE_VERSION.to_string(),
data_storage_version.to_string(),
)],
storage_options.insert(
OPT_NEW_TABLE_STORAGE_VERSION.to_string(),
data_storage_version.to_string(),
);
self
}
@@ -435,14 +381,13 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// This allows tables to automatically refresh cloud storage credentials
/// when they expire, enabling long-running operations on remote storage.
pub fn storage_options_provider(mut self, provider: Arc<dyn StorageOptionsProvider>) -> Self {
let store_params = self
.request
self.request
.write_options
.lance_write_params
.get_or_insert(Default::default())
.store_params
.get_or_insert(Default::default());
set_storage_options_provider(store_params, provider);
.get_or_insert(Default::default())
.storage_options_provider = Some(provider);
self
}
}
@@ -505,13 +450,15 @@ impl OpenTableBuilder {
///
/// See available options at <https://lancedb.com/docs/storage/>
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
let store_params = self
let storage_options = self
.request
.lance_read_params
.get_or_insert(Default::default())
.store_options
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
merge_storage_options(store_params, [(key.into(), value.into())]);
storage_options.insert(key.into(), value.into());
self
}
@@ -525,16 +472,18 @@ impl OpenTableBuilder {
mut self,
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> Self {
let store_params = self
let storage_options = self
.request
.lance_read_params
.get_or_insert(Default::default())
.store_options
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
let updates = pairs
.into_iter()
.map(|(key, value)| (key.into(), value.into()));
merge_storage_options(store_params, updates);
for (key, value) in pairs {
storage_options.insert(key.into(), value.into());
}
self
}
@@ -558,13 +507,12 @@ impl OpenTableBuilder {
/// This allows tables to automatically refresh cloud storage credentials
/// when they expire, enabling long-running operations on remote storage.
pub fn storage_options_provider(mut self, provider: Arc<dyn StorageOptionsProvider>) -> Self {
let store_params = self
.request
self.request
.lance_read_params
.get_or_insert(Default::default())
.store_options
.get_or_insert(Default::default());
set_storage_options_provider(store_params, provider);
.get_or_insert(Default::default())
.storage_options_provider = Some(provider);
self
}
@@ -920,10 +868,6 @@ pub struct ConnectBuilder {
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}
#[cfg(feature = "remote")]
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] =
[("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];
impl ConnectBuilder {
/// Create a new [`ConnectOptions`] with the given database URI.
pub fn new(uri: &str) -> Self {
@@ -1107,27 +1051,11 @@ impl ConnectBuilder {
self
}
#[cfg(feature = "remote")]
fn apply_env_defaults(
env_var_to_remote_storage_option: &[(&str, &str)],
options: &mut HashMap<String, String>,
) {
for (env_key, opt_key) in env_var_to_remote_storage_option {
if let Ok(env_value) = std::env::var(env_key) {
if !options.contains_key(*opt_key) {
options.insert((*opt_key).to_string(), env_value);
}
}
}
}
#[cfg(feature = "remote")]
fn execute_remote(self) -> Result<Connection> {
use crate::remote::db::RemoteDatabaseOptions;
let mut merged_options = self.request.options.clone();
Self::apply_env_defaults(&ENV_VARS_TO_STORAGE_OPTS, &mut merged_options);
let options = RemoteDatabaseOptions::parse_from_map(&merged_options)?;
let options = RemoteDatabaseOptions::parse_from_map(&self.request.options)?;
let region = options.region.ok_or_else(|| Error::InvalidInput {
message: "A region is required when connecting to LanceDb Cloud".to_string(),
@@ -1349,6 +1277,8 @@ mod test_utils {
#[cfg(test)]
mod tests {
use std::fs::create_dir_all;
use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
use crate::query::QueryBase;
use crate::query::{ExecutableQuery, QueryExecutionOptions};
@@ -1372,23 +1302,6 @@ mod tests {
assert_eq!(tc.connection.uri(), tc.uri);
}
#[cfg(feature = "remote")]
#[test]
fn test_apply_env_defaults() {
let env_key = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_KEY";
let env_val = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_VAL";
let opts_key = "test_apply_env_defaults_environment_variable_opts_key";
std::env::set_var(env_key, env_val);
let mut options = HashMap::new();
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
assert_eq!(Some(&env_val.to_string()), options.get(opts_key));
options.insert(opts_key.to_string(), "EXPLICIT-VALUE".to_string());
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
}
#[cfg(not(windows))]
#[tokio::test]
async fn test_connect_relative() {
@@ -1412,27 +1325,25 @@ mod tests {
#[tokio::test]
async fn test_table_names() {
let tc = new_test_connection().await.unwrap();
let db = tc.connection;
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
let tmp_dir = tempdir().unwrap();
let mut names = Vec::with_capacity(100);
for _ in 0..100 {
let name = uuid::Uuid::new_v4().to_string();
let mut name = uuid::Uuid::new_v4().to_string();
names.push(name.clone());
db.create_empty_table(name, schema.clone())
.execute()
.await
.unwrap();
name.push_str(".lance");
create_dir_all(tmp_dir.path().join(&name)).unwrap();
}
names.sort();
let tables = db.table_names().limit(100).execute().await.unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let tables = db.table_names().execute().await.unwrap();
assert_eq!(tables, names);
let tables = db
.table_names()
.start_after(&names[30])
.limit(100)
.execute()
.await
.unwrap();
@@ -1613,27 +1524,18 @@ mod tests {
#[tokio::test]
async fn drop_table() {
let tc = new_test_connection().await.unwrap();
let db = tc.connection;
let tmp_dir = tempdir().unwrap();
if tc.is_remote {
// All the typical endpoints such as s3:///, file-object-store:///, etc. treat drop_table
// as idempotent.
assert!(db.drop_table("invalid_table", &[]).await.is_ok());
} else {
// The behavior of drop_table when using a file:/// endpoint differs from all other
// object providers, in that it returns an error when deleting a non-existent table.
assert!(matches!(
db.drop_table("invalid_table", &[]).await,
Err(crate::Error::TableNotFound { .. }),
));
}
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
db.create_empty_table("table1", schema.clone())
.execute()
.await
.unwrap();
// drop non-exist table
assert!(matches!(
db.drop_table("invalid_table", &[]).await,
Err(crate::Error::TableNotFound { .. }),
));
create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
db.drop_table("table1", &[]).await.unwrap();
let tables = db.table_names().execute().await.unwrap();
@@ -1720,128 +1622,4 @@ mod tests {
let cloned_count = cloned_table.count_rows(None).await.unwrap();
assert_eq!(source_count, cloned_count);
}
#[tokio::test]
async fn test_create_empty_table_with_embeddings() {
use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction};
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use std::borrow::Cow;
#[derive(Debug, Clone)]
struct MockEmbedding {
dim: usize,
}
impl EmbeddingFunction for MockEmbedding {
fn name(&self) -> &str {
"test_embedding"
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
self.dim as i32,
true,
)))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let len = source.len();
let values = vec![1.0f32; len * self.dim];
let values = Arc::new(Float32Array::from(values));
let field = Arc::new(Field::new("item", DataType::Float32, true));
Ok(Arc::new(FixedSizeListArray::new(
field,
self.dim as i32,
values,
None,
)))
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let embed_func = Arc::new(MockEmbedding { dim: 128 });
db.embedding_registry()
.register("test_embedding", embed_func.clone())
.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let ed = EmbeddingDefinition {
source_column: "name".to_owned(),
dest_column: Some("name_embedding".to_owned()),
embedding_name: "test_embedding".to_owned(),
};
let table = db
.create_empty_table("test", schema)
.mode(CreateTableMode::Overwrite)
.add_embedding(ed)
.unwrap()
.execute()
.await
.unwrap();
let table_schema = table.schema().await.unwrap();
assert!(table_schema.column_with_name("name").is_some());
assert!(table_schema.column_with_name("name_embedding").is_some());
let embedding_field = table_schema.field_with_name("name_embedding").unwrap();
assert_eq!(
embedding_field.data_type(),
&DataType::new_fixed_size_list(DataType::Float32, 128, true)
);
let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let input_batch = RecordBatch::try_new(
input_schema.clone(),
vec![Arc::new(StringArray::from(vec![
Some("Alice"),
Some("Bob"),
Some("Charlie"),
]))],
)
.unwrap();
let input_reader = Box::new(RecordBatchIterator::new(
vec![Ok(input_batch)].into_iter(),
input_schema,
));
table.add(input_reader).execute().await.unwrap();
let results = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 3);
assert!(batch.column_by_name("name_embedding").is_some());
let embedding_col = batch
.column_by_name("name_embedding")
.unwrap()
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap();
assert_eq!(embedding_col.len(), 3);
}
}

View File

@@ -12,7 +12,7 @@ use lance::dataset::{builder::DatasetBuilder, ReadParams, WriteMode};
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use lance_datafusion::utils::StreamingWriteSource;
use lance_encoding::version::LanceFileVersion;
use lance_io::object_store::{StorageOptionsAccessor, StorageOptionsProvider};
use lance_io::object_store::StorageOptionsProvider;
use lance_table::io::commit::commit_handler_from_url;
use object_store::local::LocalFileSystem;
use snafu::ResultExt;
@@ -356,13 +356,7 @@ impl ListingDatabase {
.clone()
.unwrap_or_else(|| Arc::new(lance::session::Session::default()));
let os_params = ObjectStoreParams {
storage_options_accessor: if options.storage_options.is_empty() {
None
} else {
Some(Arc::new(StorageOptionsAccessor::with_static_options(
options.storage_options.clone(),
)))
},
storage_options: Some(options.storage_options.clone()),
..Default::default()
};
let (object_store, base_path) = ObjectStore::from_uri_and_params(
@@ -469,20 +463,9 @@ impl ListingDatabase {
validate_table_name(name)?;
let mut uri = self.uri.clone();
// If the URI does not end with a path separator, add one
// Use forward slash for URIs (http://, s3://, gs://, file://, etc.)
// Use platform-specific separator for local paths without scheme
let has_scheme = uri.contains("://");
let ends_with_separator = uri.ends_with('/') || uri.ends_with('\\');
if !ends_with_separator {
if has_scheme {
// URIs always use forward slash
uri.push('/');
} else {
// Local path without scheme - use platform separator
uri.push(std::path::MAIN_SEPARATOR);
}
// If the URI does not end with a slash, add one
if !uri.ends_with('/') {
uri.push('/');
}
// Append the table name with the lance file extension
uri.push_str(&format!("{}.{}", name, LANCE_FILE_EXTENSION));
@@ -498,13 +481,7 @@ impl ListingDatabase {
async fn drop_tables(&self, names: Vec<String>) -> Result<()> {
let object_store_params = ObjectStoreParams {
storage_options_accessor: if self.storage_options.is_empty() {
None
} else {
Some(Arc::new(StorageOptionsAccessor::with_static_options(
self.storage_options.clone(),
)))
},
storage_options: Some(self.storage_options.clone()),
..Default::default()
};
let mut uri = self.uri.clone();
@@ -553,7 +530,7 @@ impl ListingDatabase {
.lance_write_params
.as_ref()
.and_then(|p| p.store_params.as_ref())
.and_then(|sp| sp.storage_options());
.and_then(|sp| sp.storage_options.as_ref());
let storage_version_override = storage_options
.and_then(|opts| opts.get(OPT_NEW_TABLE_STORAGE_VERSION))
@@ -604,20 +581,21 @@ impl ListingDatabase {
// will cause a new connection to be created, and that connection will
// be dropped from the cache when python GCs the table object, which
// confounds reuse across tables.
if !self.storage_options.is_empty() || self.storage_options_provider.is_some() {
let store_params = write_params
if !self.storage_options.is_empty() {
let storage_options = write_params
.store_params
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default);
let mut storage_options = store_params.storage_options().cloned().unwrap_or_default();
if !self.storage_options.is_empty() {
self.inherit_storage_options(&mut storage_options);
}
let accessor = if let Some(ref provider) = self.storage_options_provider {
StorageOptionsAccessor::with_initial_and_provider(storage_options, provider.clone())
} else {
StorageOptionsAccessor::with_static_options(storage_options)
};
store_params.storage_options_accessor = Some(Arc::new(accessor));
self.inherit_storage_options(storage_options);
}
// Set storage options provider if available
if self.storage_options_provider.is_some() {
write_params
.store_params
.get_or_insert_with(Default::default)
.storage_options_provider = self.storage_options_provider.clone();
}
write_params.data_storage_version = self
@@ -903,13 +881,7 @@ impl Database for ListingDatabase {
validate_table_name(&request.target_table_name)?;
let storage_params = ObjectStoreParams {
storage_options_accessor: if self.storage_options.is_empty() {
None
} else {
Some(Arc::new(StorageOptionsAccessor::with_static_options(
self.storage_options.clone(),
)))
},
storage_options: Some(self.storage_options.clone()),
..Default::default()
};
let read_params = ReadParams {
@@ -973,28 +945,25 @@ impl Database for ListingDatabase {
// will cause a new connection to be created, and that connection will
// be dropped from the cache when python GCs the table object, which
// confounds reuse across tables.
if !self.storage_options.is_empty() || self.storage_options_provider.is_some() {
let store_params = request
if !self.storage_options.is_empty() {
let storage_options = request
.lance_read_params
.get_or_insert_with(Default::default)
.store_options
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default);
let mut storage_options = store_params.storage_options().cloned().unwrap_or_default();
if !self.storage_options.is_empty() {
self.inherit_storage_options(&mut storage_options);
}
// Preserve request-level provider if no connection-level provider exists
let request_provider = store_params
.storage_options_accessor
.as_ref()
.and_then(|a| a.provider().cloned());
let provider = self.storage_options_provider.clone().or(request_provider);
let accessor = if let Some(provider) = provider {
StorageOptionsAccessor::with_initial_and_provider(storage_options, provider)
} else {
StorageOptionsAccessor::with_static_options(storage_options)
};
store_params.storage_options_accessor = Some(Arc::new(accessor));
self.inherit_storage_options(storage_options);
}
// Set storage options provider if available
if self.storage_options_provider.is_some() {
request
.lance_read_params
.get_or_insert_with(Default::default)
.store_options
.get_or_insert_with(Default::default)
.storage_options_provider = self.storage_options_provider.clone();
}
// Some ReadParams are exposed in the OpenTableBuilder, but we also
@@ -1102,7 +1071,6 @@ mod tests {
use crate::table::{Table, TableDefinition};
use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::path::PathBuf;
use tempfile::tempdir;
async fn setup_database() -> (tempfile::TempDir, ListingDatabase) {
@@ -1901,9 +1869,7 @@ mod tests {
let write_options = WriteOptions {
lance_write_params: Some(lance::dataset::WriteParams {
store_params: Some(lance::io::ObjectStoreParams {
storage_options_accessor: Some(Arc::new(
StorageOptionsAccessor::with_static_options(storage_options),
)),
storage_options: Some(storage_options),
..Default::default()
}),
..Default::default()
@@ -1977,9 +1943,7 @@ mod tests {
let write_options = WriteOptions {
lance_write_params: Some(lance::dataset::WriteParams {
store_params: Some(lance::io::ObjectStoreParams {
storage_options_accessor: Some(Arc::new(
StorageOptionsAccessor::with_static_options(storage_options),
)),
storage_options: Some(storage_options),
..Default::default()
}),
..Default::default()
@@ -2082,19 +2046,6 @@ mod tests {
assert_eq!(db_options.new_table_config.enable_stable_row_ids, None);
}
#[tokio::test]
async fn test_table_uri() {
let (_tempdir, db) = setup_database().await;
let mut pb = PathBuf::new();
pb.push(db.uri.clone());
pb.push("test.lance");
let expected = pb.to_str().unwrap();
let uri = db.table_uri("test").ok().unwrap();
assert_eq!(uri, expected);
}
#[tokio::test]
async fn test_namespace_client() {
let (_tempdir, db) = setup_database().await;

View File

@@ -10,14 +10,13 @@ use async_trait::async_trait;
use lance_namespace::{
models::{
CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
DeclareTableRequest, DescribeNamespaceRequest, DescribeNamespaceResponse,
DescribeTableRequest, DropNamespaceRequest, DropNamespaceResponse, DropTableRequest,
ListNamespacesRequest, ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest,
DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest,
ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
},
LanceNamespace,
};
use lance_namespace_impls::ConnectBuilder;
use log::warn;
use crate::database::ReadConsistency;
use crate::error::{Error, Result};
@@ -138,7 +137,6 @@ impl Database for LanceNamespaceDatabase {
id: Some(request.namespace),
page_token: request.start_after,
limit: request.limit.map(|l| l as i32),
..Default::default()
};
let response = self.namespace.list_tables(ns_request).await?;
@@ -155,7 +153,8 @@ impl Database for LanceNamespaceDatabase {
table_id.push(request.name.clone());
let describe_request = DescribeTableRequest {
id: Some(table_id.clone()),
..Default::default()
version: None,
with_table_uri: None,
};
let describe_result = self.namespace.describe_table(describe_request).await;
@@ -173,7 +172,6 @@ impl Database for LanceNamespaceDatabase {
// Drop the existing table - must succeed
let drop_request = DropTableRequest {
id: Some(table_id.clone()),
..Default::default()
};
self.namespace
.drop_table(drop_request)
@@ -205,53 +203,29 @@ impl Database for LanceNamespaceDatabase {
let mut table_id = request.namespace.clone();
table_id.push(request.name.clone());
// Try declare_table first, falling back to create_empty_table for backwards
// compatibility with older namespace clients that don't support declare_table
let declare_request = DeclareTableRequest {
let create_empty_request = CreateEmptyTableRequest {
id: Some(table_id.clone()),
..Default::default()
location: None,
properties: if self.storage_options.is_empty() {
None
} else {
Some(self.storage_options.clone())
},
};
let location = match self.namespace.declare_table(declare_request).await {
Ok(response) => response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from declare_table response".to_string(),
})?,
Err(e) => {
// Check if the error is "not supported" and try create_empty_table as fallback
let err_str = e.to_string().to_lowercase();
if err_str.contains("not supported") || err_str.contains("not implemented") {
warn!(
"declare_table is not supported by the namespace client, \
falling back to deprecated create_empty_table. \
create_empty_table is deprecated and will be removed in Lance 3.0.0. \
Please upgrade your namespace client to support declare_table."
);
#[allow(deprecated)]
let create_empty_request = CreateEmptyTableRequest {
id: Some(table_id.clone()),
..Default::default()
};
let create_empty_response = self
.namespace
.create_empty_table(create_empty_request)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to create empty table: {}", e),
})?;
#[allow(deprecated)]
let create_response = self
.namespace
.create_empty_table(create_empty_request)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to create empty table: {}", e),
})?;
create_response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from create_empty_table response"
.to_string(),
})?
} else {
return Err(Error::Runtime {
message: format!("Failed to declare table: {}", e),
});
}
}
};
let location = create_empty_response
.location
.ok_or_else(|| Error::Runtime {
message: "Table location is missing from create_empty_table response".to_string(),
})?;
let native_table = NativeTable::create_from_namespace(
self.namespace.clone(),
@@ -308,10 +282,7 @@ impl Database for LanceNamespaceDatabase {
let mut table_id = namespace.to_vec();
table_id.push(name.to_string());
let drop_request = DropTableRequest {
id: Some(table_id),
..Default::default()
};
let drop_request = DropTableRequest { id: Some(table_id) };
self.namespace
.drop_table(drop_request)
.await
@@ -466,7 +437,8 @@ mod tests {
// Create a child namespace first
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -526,7 +498,8 @@ mod tests {
// Create a child namespace first
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -589,7 +562,8 @@ mod tests {
// Create a child namespace first
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -672,7 +646,8 @@ mod tests {
// Create a child namespace first
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -727,7 +702,8 @@ mod tests {
// Create a child namespace first
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -807,7 +783,8 @@ mod tests {
// Create a child namespace first
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -840,7 +817,8 @@ mod tests {
// Create a child namespace first
conn.create_namespace(CreateNamespaceRequest {
id: Some(vec!["test_ns".into()]),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");

View File

@@ -19,7 +19,7 @@ use crate::{
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase, Select},
query::{ExecutableQuery, QueryBase},
Error, Result, Table,
};
@@ -27,8 +27,6 @@ pub const SRC_ROW_ID_COL: &str = "row_id";
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
pub const DEFAULT_MEMORY_LIMIT: usize = 100 * 1024 * 1024;
/// Where to store the permutation table
#[derive(Debug, Clone, Default)]
enum PermutationDestination {
@@ -169,20 +167,10 @@ impl PermutationBuilder {
&self,
data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
let memory_limit = std::env::var("LANCEDB_PERM_BUILDER_MEMORY_LIMIT")
.unwrap_or_else(|_| DEFAULT_MEMORY_LIMIT.to_string())
.parse::<usize>()
.unwrap_or_else(|_| {
log::error!(
"Failed to parse LANCEDB_PERM_BUILDER_MEMORY_LIMIT, using default: {}",
DEFAULT_MEMORY_LIMIT
);
DEFAULT_MEMORY_LIMIT
});
let ctx = SessionContext::new_with_config_rt(
SessionConfig::default(),
RuntimeEnvBuilder::new()
.with_memory_limit(memory_limit, 1.0)
.with_memory_limit(100 * 1024 * 1024, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default()
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
@@ -244,7 +232,7 @@ impl PermutationBuilder {
/// Builds the permutation table and stores it in the given database.
pub async fn build(self) -> Result<Table> {
// First pass, apply filter and load row ids
let mut rows = self.base_table.query().select(Select::columns(&[ROW_ID]));
let mut rows = self.base_table.query().with_row_id();
if let Some(filter) = &self.config.filter {
rows = rows.only_if(filter);
@@ -333,47 +321,6 @@ mod tests {
use super::*;
#[tokio::test]
async fn test_permutation_table_only_stores_row_id_and_split_id() {
let temp_dir = tempfile::tempdir().unwrap();
let db = connect(temp_dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let initial_data = lance_datagen::gen_batch()
.col("col_a", lance_datagen::array::step::<Int32Type>())
.col("col_b", lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
let data_table = db
.create_table_streaming("base_tbl", initial_data)
.execute()
.await
.unwrap();
let permutation_table = PermutationBuilder::new(data_table.clone())
.with_split_strategy(
SplitStrategy::Sequential {
sizes: SplitSizes::Percentages(vec![0.5, 0.5]),
},
None,
)
.with_filter("col_a > 57".to_string())
.build()
.await
.unwrap();
let schema = permutation_table.schema().await.unwrap();
let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
assert_eq!(
field_names,
vec!["row_id", "split_id"],
"Permutation table should only contain row_id and split_id columns, but found: {:?}",
field_names,
);
}
#[tokio::test]
async fn test_permutation_builder() {
let temp_dir = tempfile::tempdir().unwrap();
@@ -405,6 +352,8 @@ mod tests {
.await
.unwrap();
println!("permutation_table: {:?}", permutation_table);
// Potentially brittle seed-dependent values below
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
assert_eq!(

View File

@@ -171,7 +171,7 @@ impl Shuffler {
// This is kind of an annoying limitation but if we allow runt clumps from batches then
// clumps will get unaligned and we will mess up the clumps when we do the in-memory
// shuffle step. If this is a problem we can probably figure out a better way to do this.
if !is_last && !(batch.num_rows() as u64).is_multiple_of(clump_size) {
if !is_last && batch.num_rows() as u64 % clump_size != 0 {
return Err(Error::Runtime {
message: format!(
"Expected batch size ({}) to be divisible by clump size ({})",

View File

@@ -1,9 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
Arc,
use std::{
iter,
sync::{
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
Arc,
},
};
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
@@ -12,8 +15,6 @@ use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance_arrow::SchemaExt;
use lance_core::ROW_ID;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
@@ -157,7 +158,7 @@ impl Splitter {
remaining_in_split
};
split_ids.extend(std::iter::repeat_n(split_id as u64, rows_to_add as usize));
split_ids.extend(iter::repeat(split_id as u64).take(rows_to_add as usize));
if done {
// Quit early if we've run out of splits
break;
@@ -362,15 +363,11 @@ impl Splitter {
pub fn project(&self, query: Query) -> Query {
match &self.strategy {
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![
(SPLIT_ID_COLUMN.to_string(), calculation.clone()),
(ROW_ID.to_string(), ROW_ID.to_string()),
])),
SplitStrategy::Hash { columns, .. } => {
let mut cols = columns.clone();
cols.push(ROW_ID.to_string());
query.select(Select::Columns(cols))
}
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
SPLIT_ID_COLUMN.to_string(),
calculation.clone(),
)])),
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
_ => query,
}
}
@@ -665,7 +662,7 @@ mod tests {
assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
let mut expected = Vec::with_capacity(total_split_sizes as usize);
for (i, size) in expected_split_sizes.iter().enumerate() {
expected.extend(std::iter::repeat_n(i as u64, *size as usize));
expected.extend(iter::repeat(i as u64).take(*size as usize));
}
let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;

View File

@@ -120,13 +120,8 @@ impl MemoryRegistry {
}
/// A record batch reader that has embeddings applied to it
///
/// This is a wrapper around another record batch reader that applies embedding functions
/// when reading from the record batch.
///
/// When multiple embedding functions are defined, they are computed in parallel using
/// scoped threads to improve performance. For a single embedding function, computation
/// is done inline without threading overhead.
/// This is a wrapper around another record batch reader that applies an embedding function
/// when reading from the record batch
pub struct WithEmbeddings<R: RecordBatchReader> {
inner: R,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
@@ -240,48 +235,6 @@ impl<R: RecordBatchReader> WithEmbeddings<R> {
column_definitions,
})
}
fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result<Vec<Arc<dyn Array>>> {
if self.embeddings.len() == 1 {
let (fld, func) = &self.embeddings[0];
let src_column =
batch
.column_by_name(&fld.source_column)
.ok_or_else(|| Error::InvalidInput {
message: format!("Source column '{}' not found", fld.source_column),
})?;
return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
}
// Parallel path: multiple embeddings
std::thread::scope(|s| {
let handles: Vec<_> = self
.embeddings
.iter()
.map(|(fld, func)| {
let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
Error::InvalidInput {
message: format!("Source column '{}' not found", fld.source_column),
}
})?;
let handle =
s.spawn(move || func.compute_source_embeddings(src_column.clone()));
Ok(handle)
})
.collect::<Result<_>>()?;
handles
.into_iter()
.map(|h| {
h.join().map_err(|e| Error::Runtime {
message: format!("Thread panicked during embedding computation: {:?}", e),
})?
})
.collect()
})
}
}
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
@@ -309,19 +262,19 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
fn next(&mut self) -> Option<Self::Item> {
let batch = self.inner.next()?;
match batch {
Ok(batch) => {
let embeddings = match self.compute_embeddings_parallel(&batch) {
Ok(emb) => emb,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
"Error computing embedding: {}",
e
))))
}
};
let mut batch = batch;
for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) {
Ok(mut batch) => {
// todo: parallelize this
for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.compute_source_embeddings(src_column.clone()) {
Ok(embedding) => embedding,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
"Error computing embedding: {}",
e
))))
}
};
let dst_field_name = fld
.dest_column
.clone()
@@ -333,7 +286,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
embedding.nulls().is_some(),
);
match batch.try_with_column(dst_field.clone(), embedding.clone()) {
match batch.try_with_column(dst_field.clone(), embedding) {
Ok(b) => batch = b,
Err(e) => return Some(Err(e)),
};

View File

@@ -297,10 +297,10 @@ impl IvfPqIndexBuilder {
}
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
if dim.is_multiple_of(16) {
if dim % 16 == 0 {
// Should be more aggressive than this default.
dim / 16
} else if dim.is_multiple_of(8) {
} else if dim % 8 == 0 {
dim / 8
} else {
log::warn!(

View File

@@ -25,14 +25,13 @@
//!
//! ## Crate Features
//!
//! - `aws` - Enable AWS S3 object store support.
//! - `dynamodb` - Enable DynamoDB manifest store support.
//! - `azure` - Enable Azure Blob Storage object store support.
//! - `gcs` - Enable Google Cloud Storage object store support.
//! - `oss` - Enable Alibaba Cloud OSS object store support.
//! - `remote` - Enable remote client to connect to LanceDB cloud.
//! - `huggingface` - Enable HuggingFace Hub integration for loading datasets from the Hub.
//! - `fp16kernels` - Enable FP16 kernels for faster vector search on CPU.
//! ### Experimental Features
//!
//! These features are not enabled by default. They are experimental or in-development features that
//! are not yet ready to be released.
//!
//! - `remote` - Enable remote client to connect to LanceDB cloud. This is not yet fully implemented
//! and should not be enabled.
//!
//! ### Quick Start
//!
@@ -51,15 +50,17 @@
//! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
//! - `db://dbname` - Lance Cloud
//!
//! You can also use [`ConnectBuilder`] to configure the connection to the database.
//! You can also use [`ConnectOptions`] to configure the connection to the database.
//!
//! ```rust
//! use object_store::aws::AwsCredential;
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = lancedb::connect("data/sample-lancedb")
//! .storage_options([
//! ("aws_access_key_id", "some_key"),
//! ("aws_secret_access_key", "some_secret"),
//! ])
//! .aws_creds(AwsCredential {
//! key_id: "some_key".to_string(),
//! secret_key: "some_secret".to_string(),
//! token: None,
//! })
//! .execute()
//! .await
//! .unwrap();

View File

@@ -1718,7 +1718,8 @@ mod tests {
let namespace = vec!["test_ns".to_string()];
conn.create_namespace(CreateNamespaceRequest {
id: Some(namespace.clone()),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -1743,7 +1744,8 @@ mod tests {
let list_response = conn
.list_tables(ListTablesRequest {
id: Some(namespace.clone()),
..Default::default()
page_token: None,
limit: None,
})
.await
.expect("Failed to list tables");
@@ -1754,7 +1756,8 @@ mod tests {
let list_response = namespace_client
.list_tables(ListTablesRequest {
id: Some(namespace.clone()),
..Default::default()
page_token: None,
limit: None,
})
.await
.unwrap();
@@ -1794,7 +1797,8 @@ mod tests {
let namespace = vec!["multi_table_ns".to_string()];
conn.create_namespace(CreateNamespaceRequest {
id: Some(namespace.clone()),
..Default::default()
mode: None,
properties: None,
})
.await
.expect("Failed to create namespace");
@@ -1819,7 +1823,8 @@ mod tests {
let list_response = conn
.list_tables(ListTablesRequest {
id: Some(namespace.clone()),
..Default::default()
page_token: None,
limit: None,
})
.await
.unwrap();

View File

@@ -204,7 +204,6 @@ pub struct RemoteTable<S: HttpSend = Sender> {
server_version: ServerVersion,
version: RwLock<Option<u64>>,
location: RwLock<Option<String>>,
}
impl<S: HttpSend> RemoteTable<S> {
@@ -222,7 +221,6 @@ impl<S: HttpSend> RemoteTable<S> {
identifier,
server_version,
version: RwLock::new(None),
location: RwLock::new(None),
}
}
@@ -641,7 +639,6 @@ impl<S: HttpSend> RemoteTable<S> {
struct TableDescription {
version: u64,
schema: JsonSchema,
location: Option<String>,
}
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
@@ -670,7 +667,6 @@ mod test_utils {
identifier: name,
server_version: version.map(ServerVersion).unwrap_or_default(),
version: RwLock::new(None),
location: RwLock::new(None),
}
}
}
@@ -1092,17 +1088,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::IvfRq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_RQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
if let Some(num_bits) = index.num_bits {
body["num_bits"] = serde_json::Value::Number(num_bits.into());
}
}
Index::BTree(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
}
@@ -1465,28 +1450,8 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
message: "table_definition is not supported on LanceDB cloud.".into(),
})
}
async fn uri(&self) -> Result<String> {
// Check if we already have the location cached
{
let location = self.location.read().await;
if let Some(ref loc) = *location {
return Ok(loc.clone());
}
}
// Fetch from server via describe
let description = self.describe().await?;
let location = description.location.ok_or_else(|| Error::NotSupported {
message: "Table URI not supported by the server".into(),
})?;
// Cache the location for future use
{
let mut cached_location = self.location.write().await;
*cached_location = Some(location.clone());
}
Ok(location)
fn dataset_uri(&self) -> &str {
"NOT_SUPPORTED"
}
async fn storage_options(&self) -> Option<HashMap<String, String>> {
@@ -3356,69 +3321,4 @@ mod tests {
let result = table.drop_columns(&["old_col1", "old_col2"]).await.unwrap();
assert_eq!(result.version, 5);
}
#[tokio::test]
async fn test_uri() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
http::Response::builder()
.status(200)
.body(r#"{"version": 1, "schema": {"fields": []}, "location": "s3://bucket/path/to/table"}"#)
.unwrap()
});
let uri = table.uri().await.unwrap();
assert_eq!(uri, "s3://bucket/path/to/table");
}
#[tokio::test]
async fn test_uri_missing_location() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
// Server returns response without location field
http::Response::builder()
.status(200)
.body(r#"{"version": 1, "schema": {"fields": []}}"#)
.unwrap()
});
let result = table.uri().await;
assert!(result.is_err());
assert!(matches!(&result, Err(Error::NotSupported { .. })));
}
#[tokio::test]
async fn test_uri_caching() {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
call_count_clone.fetch_add(1, Ordering::SeqCst);
http::Response::builder()
.status(200)
.body(
r#"{"version": 1, "schema": {"fields": []}, "location": "gs://bucket/table"}"#,
)
.unwrap()
});
// First call should fetch from server
let uri1 = table.uri().await.unwrap();
assert_eq!(uri1, "gs://bucket/table");
assert_eq!(call_count.load(Ordering::SeqCst), 1);
// Second call should use cached value
let uri2 = table.uri().await.unwrap();
assert_eq!(uri2, "gs://bucket/table");
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call
}
}

View File

@@ -40,7 +40,7 @@ use lance_index::vector::pq::PQBuildParams;
use lance_index::vector::sq::builder::SQBuildParams;
use lance_index::DatasetIndexExt;
use lance_index::IndexType;
use lance_io::object_store::{LanceNamespaceStorageOptionsProvider, StorageOptionsAccessor};
use lance_io::object_store::LanceNamespaceStorageOptionsProvider;
use lance_namespace::models::{
QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns,
QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery,
@@ -79,11 +79,10 @@ use self::merge::MergeInsertBuilder;
pub mod datafusion;
pub(crate) mod dataset;
pub mod delete;
pub mod merge;
use crate::index::waiter::wait_for_index;
pub use chrono::Duration;
pub use delete::DeleteResult;
use futures::future::{join_all, Either};
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
@@ -447,6 +446,15 @@ pub struct AddResult {
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct MergeResult {
// The commit version associated with the operation.
@@ -600,8 +608,8 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn list_versions(&self) -> Result<Vec<Version>>;
/// Get the table definition.
async fn table_definition(&self) -> Result<TableDefinition>;
/// Get the table URI (storage location)
async fn uri(&self) -> Result<String>;
/// Get the table URI
fn dataset_uri(&self) -> &str;
/// Get the storage options used when opening this table, if any.
async fn storage_options(&self) -> Option<HashMap<String, String>>;
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
@@ -1309,12 +1317,11 @@ impl Table {
self.inner.list_indices().await
}
/// Get the table URI (storage location)
/// Get the underlying dataset URI
///
/// Returns the full storage location of the table (e.g., S3/GCS path).
/// For remote tables, this fetches the location from the server via describe.
pub async fn uri(&self) -> Result<String> {
self.inner.uri().await
/// Warning: This is an internal API and the return value is subject to change.
pub fn dataset_uri(&self) -> &str {
self.inner.dataset_uri()
}
/// Get the storage options used when opening this table, if any.
@@ -1658,14 +1665,18 @@ impl NativeTable {
// Use DatasetBuilder::from_namespace which automatically fetches location
// and storage options from the namespace
let builder = DatasetBuilder::from_namespace(namespace_client.clone(), table_id)
.await
.map_err(|e| match e {
lance::Error::Namespace { source, .. } => Error::Runtime {
message: format!("Failed to get table info from namespace: {:?}", source),
},
source => Error::Lance { source },
})?;
let builder = DatasetBuilder::from_namespace(
namespace_client.clone(),
table_id,
false, // Don't ignore namespace storage options
)
.await
.map_err(|e| match e {
lance::Error::Namespace { source, .. } => Error::Runtime {
message: format!("Failed to get table info from namespace: {:?}", source),
},
source => Error::Lance { source },
})?;
let dataset = builder
.with_read_params(params)
@@ -1869,13 +1880,7 @@ impl NativeTable {
let store_params = params
.store_params
.get_or_insert_with(ObjectStoreParams::default);
let accessor = match store_params.storage_options().cloned() {
Some(options) => {
StorageOptionsAccessor::with_initial_and_provider(options, storage_options_provider)
}
None => StorageOptionsAccessor::with_provider(storage_options_provider),
};
store_params.storage_options_accessor = Some(Arc::new(accessor));
store_params.storage_options_provider = Some(storage_options_provider);
// Patch the params if we have a write store wrapper
let params = match write_store_wrapper.clone() {
@@ -2051,7 +2056,7 @@ impl NativeTable {
return provided;
}
let suggested = suggested_num_sub_vectors(dim);
if num_bits.is_some_and(|num_bits| num_bits == 4) && !suggested.is_multiple_of(2) {
if num_bits.is_some_and(|num_bits| num_bits == 4) && suggested % 2 != 0 {
// num_sub_vectors must be even when 4 bits are used
suggested + 1
} else {
@@ -2399,7 +2404,7 @@ impl NativeTable {
with_row_id: Some(vq.base.with_row_id),
bypass_vector_index: Some(!vq.use_index),
full_text_query,
..Default::default()
version: None,
})
}
AnyQuery::Query(q) => {
@@ -2461,11 +2466,18 @@ impl NativeTable {
columns,
prefilter: Some(q.prefilter),
offset: q.offset.map(|o| o as i32),
ef: None,
refine_factor: None,
distance_type: None,
nprobes: None,
vector_column: None, // No vector column for plain queries
with_row_id: Some(q.with_row_id),
bypass_vector_index: Some(true), // No vector index for plain queries
full_text_query,
..Default::default()
version: None,
fast_search: None,
lower_bound: None,
upper_bound: None,
})
}
}
@@ -3070,8 +3082,11 @@ impl BaseTable for NativeTable {
/// Delete rows from the table
async fn delete(&self, predicate: &str) -> Result<DeleteResult> {
// Delegate to the submodule implementation
delete::execute_delete(self, predicate).await
let mut dataset = self.dataset.get_mut().await?;
dataset.delete(predicate).await?;
Ok(DeleteResult {
version: dataset.version().version,
})
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
@@ -3215,8 +3230,8 @@ impl BaseTable for NativeTable {
Ok(results.into_iter().flatten().collect())
}
async fn uri(&self) -> Result<String> {
Ok(self.uri.clone())
fn dataset_uri(&self) -> &str {
self.uri.as_str()
}
async fn storage_options(&self) -> Option<HashMap<String, String>> {
@@ -3224,7 +3239,7 @@ impl BaseTable for NativeTable {
.get()
.await
.ok()
.and_then(|dataset| dataset.initial_storage_options().cloned())
.and_then(|dataset| dataset.storage_options().cloned())
}
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
@@ -3389,6 +3404,7 @@ pub struct FragmentSummaryStats {
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use std::iter;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
@@ -4005,7 +4021,7 @@ mod tests {
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))),
Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
],
)],
schema,
@@ -5140,8 +5156,8 @@ mod tests {
ns_request
.columns
.as_ref()
.and_then(|c| c.column_names.as_ref()),
Some(&vec!["id".to_string()])
.and_then(|c| c.column_names.clone()),
Some(vec!["id".to_string()])
);
assert_eq!(ns_request.vector_column, Some("vector".to_string()));
assert_eq!(ns_request.distance_type, Some("l2".to_string()));
@@ -5187,8 +5203,8 @@ mod tests {
ns_request
.columns
.as_ref()
.and_then(|c| c.column_names.as_ref()),
Some(&vec!["id".to_string()])
.and_then(|c| c.column_names.clone()),
Some(vec!["id".to_string()])
);
assert_eq!(ns_request.with_row_id, Some(true));
assert_eq!(ns_request.bypass_vector_index, Some(true));

View File

@@ -100,7 +100,8 @@ impl DatasetRef {
let should_checkout = match &target_ref {
refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
refs::Ref::Version(_, None) => true, // No specific version, always checkout
refs::Ref::Tag(_) => true, // Always checkout for tags
refs::Ref::VersionNumber(target_ver) => version != target_ver,
refs::Ref::Tag(_) => true, // Always checkout for tags
};
if should_checkout {

View File

@@ -1,161 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use serde::{Deserialize, Serialize};
use super::NativeTable;
use crate::Result;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
// The commit version associated with the operation.
// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
}
/// Internal implementation of the delete logic
///
/// This logic was moved from NativeTable::delete to keep table.rs clean.
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
// We access the dataset from the table. Since this is in the same module hierarchy (super),
// and 'dataset' is pub(crate), we can access it.
let mut dataset = table.dataset.get_mut().await?;
// Perform the actual delete on the Lance dataset
dataset.delete(predicate).await?;
// Return the result with the new version
Ok(DeleteResult {
version: dataset.version().version,
})
}
#[cfg(test)]
mod tests {
use crate::connect;
use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
use crate::query::ExecutableQuery;
use futures::TryStreamExt;
#[tokio::test]
async fn test_delete_simple() {
let conn = connect("memory://").execute().await.unwrap();
// 1. Create a table with values 0 to 9
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)
.unwrap();
let table = conn
.create_table(
"test_delete",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// 2. Verify initial state
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// 3. Execute Delete (removes values > 5)
table.delete("i > 5").await.unwrap();
// 4. Verify results
assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain
// 5. Verify specific data consistency
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let array = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
// Ensure no value > 5 exists
for val in array.iter() {
assert!(val.unwrap() <= 5);
}
}
#[tokio::test]
async fn rows_removed_schema_same() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(
("id", Int32, [1, 2, 3, 4, 5]),
("name", Utf8, ["a", "b", "c", "d", "e"])
)
.unwrap();
let original_schema = batch.schema();
let table = conn
.create_table(
"test_delete_all",
RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()),
)
.execute()
.await
.unwrap();
table.delete("true").await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 0);
let current_schema = table.schema().await.unwrap();
//check if the original schema is the same as current
assert_eq!(current_schema, original_schema);
}
#[tokio::test]
async fn test_delete_false_increments_version() {
let conn = connect("memory://").execute().await.unwrap();
// Create a table with 5 rows
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let schema = batch.schema();
let table = conn
.create_table(
"test_delete_noop",
RecordBatchIterator::new(vec![Ok(batch)], schema),
)
.execute()
.await
.unwrap();
// Capture the initial state (Rows = 5, Version = 1)
let initial_rows = table.count_rows(None).await.unwrap();
let initial_version = table.version().await.unwrap();
assert_eq!(initial_rows, 5);
table.delete("false").await.unwrap();
// Rows should still be 5
let current_rows = table.count_rows(None).await.unwrap();
assert_eq!(
current_rows, initial_rows,
"Data should not change when predicate is false"
);
// version check
let current_version = table.version().await.unwrap();
assert!(
current_version > initial_version,
"Table version must increment after delete operation"
);
}
}

View File

@@ -5,19 +5,16 @@
use regex::Regex;
use std::env;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, ChildStdout, Command};
use tokio::sync::mpsc;
use std::io::{BufRead, BufReader};
use std::process::{Child, ChildStdout, Command, Stdio};
use crate::{connect, Connection};
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Result};
use tempfile::{tempdir, TempDir};
pub struct TestConnection {
pub uri: String,
pub connection: Connection,
pub is_remote: bool,
_temp_dir: Option<TempDir>,
_process: Option<TestProcess>,
}
@@ -40,56 +37,6 @@ pub async fn new_test_connection() -> Result<TestConnection> {
}
}
async fn spawn_stdout_reader(
mut stdout: BufReader<ChildStdout>,
port_sender: mpsc::Sender<anyhow::Result<String>>,
) -> tokio::task::JoinHandle<()> {
let print_stdout = env::var("PRINT_LANCEDB_TEST_CONNECTION_SCRIPT_OUTPUT").is_ok();
tokio::spawn(async move {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
line.clear();
let result = stdout.read_line(&mut line).await;
if let Err(err) = result {
port_sender
.send(Err(anyhow!(
"error while reading from process output: {}",
err
)))
.await
.unwrap();
return;
} else if result.unwrap() == 0 {
port_sender
.send(Err(anyhow!(
" hit EOF before reading port from process output."
)))
.await
.unwrap();
return;
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
port_sender.send(Ok(caps[1].to_string())).await.unwrap();
break;
}
}
loop {
line.clear();
match stdout.read_line(&mut line).await {
Err(_) => return,
Ok(0) => return,
Ok(_size) => {
if print_stdout {
print!("{}", line);
}
}
}
}
})
}
async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
let temp_dir = tempdir()?;
let data_path = temp_dir.path().to_str().unwrap().to_string();
@@ -110,25 +57,38 @@ async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
child: child_result.unwrap(),
};
let stdout = BufReader::new(process.child.stdout.take().unwrap());
let (port_sender, mut port_receiver) = mpsc::channel(5);
let _reader = spawn_stdout_reader(stdout, port_sender).await;
let port = match port_receiver.recv().await {
None => bail!("Unable to determine the port number used by the phalanx process we spawned, because the reader thread was closed too soon."),
Some(Err(err)) => bail!("Unable to determine the port number used by the phalanx process we spawned, because of an error, {}", err),
Some(Ok(port)) => port,
};
let port = read_process_port(stdout)?;
let uri = "db://test";
let host_override = format!("http://localhost:{}", port);
let connection = create_new_connection(uri, &host_override).await?;
Ok(TestConnection {
uri: uri.to_string(),
connection,
is_remote: true,
_temp_dir: Some(temp_dir),
_process: Some(process),
})
}
fn read_process_port(mut stdout: BufReader<ChildStdout>) -> Result<String> {
let mut line = String::new();
let re = Regex::new(r"Query node now listening on 0.0.0.0:(.*)").unwrap();
loop {
let result = stdout.read_line(&mut line);
if let Err(err) = result {
bail!(format!(
"read_process_port: error while reading from process output: {}",
err
));
} else if result.unwrap() == 0 {
bail!("read_process_port: hit EOF before reading port from process output.");
}
if re.is_match(&line) {
let caps = re.captures(&line).unwrap();
return Ok(caps[1].to_string());
}
}
}
#[cfg(feature = "remote")]
async fn create_new_connection(uri: &str, host_override: &str) -> crate::error::Result<Connection> {
connect(uri)
@@ -154,7 +114,6 @@ async fn new_local_connection() -> Result<TestConnection> {
Ok(TestConnection {
uri: uri.to_string(),
connection,
is_remote: false,
_temp_dir: Some(temp_dir),
_process: None,
})

View File

@@ -4,6 +4,7 @@
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
iter::repeat,
sync::Arc,
};
@@ -267,10 +268,9 @@ fn create_some_records() -> Result<impl IntoArrow> {
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter(std::iter::repeat_n(
Some("hello world".to_string()),
TOTAL,
))),
Arc::new(StringArray::from_iter(
repeat(Some("hello world".to_string())).take(TOTAL),
)),
],
)
.unwrap()]

View File

@@ -1,253 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
borrow::Cow,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use arrow::buffer::NullBuffer;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use lancedb::{
embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
Error, Result,
};
#[derive(Debug)]
struct SlowMockEmbed {
name: String,
dim: usize,
delay_ms: u64,
call_count: Arc<AtomicUsize>,
}
impl SlowMockEmbed {
pub fn new(name: String, dim: usize, delay_ms: u64) -> Self {
Self {
name,
dim,
delay_ms,
call_count: Arc::new(AtomicUsize::new(0)),
}
}
pub fn get_call_count(&self) -> usize {
self.call_count.load(Ordering::SeqCst)
}
}
impl EmbeddingFunction for SlowMockEmbed {
fn name(&self) -> &str {
&self.name
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
self.dim as _,
true,
)))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// Simulate slow embedding computation
std::thread::sleep(Duration::from_millis(self.delay_ms));
self.call_count.fetch_add(1, Ordering::SeqCst);
let len = source.len();
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
let field = Field::new("item", inner.data_type().clone(), false);
let arr = FixedSizeListArray::new(
Arc::new(field),
self.dim as _,
inner,
Some(NullBuffer::new_valid(len)),
);
Ok(Arc::new(arr))
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
fn create_test_batch() -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
let text = StringArray::from(vec!["hello", "world"]);
RecordBatch::try_new(schema, vec![Arc::new(text)]).map_err(|e| Error::Runtime {
message: format!("Failed to create test batch: {}", e),
})
}
#[test]
fn test_single_embedding_fast_path() {
// Single embedding should execute without spawning threads
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed = Arc::new(SlowMockEmbed::new("test".to_string(), 2, 10));
let embedding_def = EmbeddingDefinition::new("text", "test", Some("embedding"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![(embedding_def, embed.clone() as Arc<dyn EmbeddingFunction>)];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap().unwrap();
assert!(result.column_by_name("embedding").is_some());
assert_eq!(embed.get_call_count(), 1);
}
#[test]
fn test_multiple_embeddings_parallel() {
// Multiple embeddings should execute in parallel
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 100));
let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 100));
let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 100));
let def1 = EmbeddingDefinition::new("text", "embed1", Some("emb1"));
let def2 = EmbeddingDefinition::new("text", "embed2", Some("emb2"));
let def3 = EmbeddingDefinition::new("text", "embed3", Some("emb3"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![
(def1, embed1.clone() as Arc<dyn EmbeddingFunction>),
(def2, embed2.clone() as Arc<dyn EmbeddingFunction>),
(def3, embed3.clone() as Arc<dyn EmbeddingFunction>),
];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap().unwrap();
// Verify all embedding columns are present
assert!(result.column_by_name("emb1").is_some());
assert!(result.column_by_name("emb2").is_some());
assert!(result.column_by_name("emb3").is_some());
// Verify all embeddings were computed
assert_eq!(embed1.get_call_count(), 1);
assert_eq!(embed2.get_call_count(), 1);
assert_eq!(embed3.get_call_count(), 1);
}
#[test]
fn test_embedding_column_order_preserved() {
// Verify that embedding columns are added in the same order as definitions
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, 10));
let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, 10));
let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, 10));
let def1 = EmbeddingDefinition::new("text", "embed1", Some("first"));
let def2 = EmbeddingDefinition::new("text", "embed2", Some("second"));
let def3 = EmbeddingDefinition::new("text", "embed3", Some("third"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![
(def1, embed1 as Arc<dyn EmbeddingFunction>),
(def2, embed2 as Arc<dyn EmbeddingFunction>),
(def3, embed3 as Arc<dyn EmbeddingFunction>),
];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap().unwrap();
let result_schema = result.schema();
// Original column is first
assert_eq!(result_schema.field(0).name(), "text");
// Embedding columns follow in order
assert_eq!(result_schema.field(1).name(), "first");
assert_eq!(result_schema.field(2).name(), "second");
assert_eq!(result_schema.field(3).name(), "third");
}
#[test]
fn test_embedding_error_propagation() {
// Test that errors from embedding computation are properly propagated
#[derive(Debug)]
struct FailingEmbed {
name: String,
}
impl EmbeddingFunction for FailingEmbed {
fn name(&self) -> &str {
&self.name
}
fn source_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
2,
true,
)))
}
fn compute_source_embeddings(&self, _source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
Err(Error::Runtime {
message: "Intentional failure".to_string(),
})
}
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let embed = Arc::new(FailingEmbed {
name: "failing".to_string(),
});
let def = EmbeddingDefinition::new("text", "failing", Some("emb"));
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let embeddings = vec![(def, embed as Arc<dyn EmbeddingFunction>)];
let mut with_embeddings = WithEmbeddings::new(reader, embeddings);
let result = with_embeddings.next().unwrap();
assert!(result.is_err());
let err_msg = format!("{}", result.err().unwrap());
assert!(err_msg.contains("Intentional failure"));
}
#[test]
fn test_maybe_embedded_with_no_embeddings() {
// Test that MaybeEmbedded::No variant works correctly
let batch = create_test_batch().unwrap();
let schema = batch.schema();
let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone());
let table_def = lancedb::table::TableDefinition {
schema: schema.clone(),
column_definitions: vec![lancedb::table::ColumnDefinition {
kind: lancedb::table::ColumnKind::Physical,
}],
};
let mut maybe_embedded = MaybeEmbedded::try_new(reader, table_def, None).unwrap();
let result = maybe_embedded.next().unwrap().unwrap();
assert_eq!(result.num_columns(), 1);
assert_eq!(result.column(0).as_ref(), batch.column(0).as_ref());
}