diff --git a/.bumpversion.toml b/.bumpversion.toml index d87d069f9..a862d28c0 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "0.30.0-beta.0" +current_version = "0.30.0-beta.1" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/.github/workflows/build_windows_wheel/action.yml b/.github/workflows/build_windows_wheel/action.yml index 742b187b2..18f39818f 100644 --- a/.github/workflows/build_windows_wheel/action.yml +++ b/.github/workflows/build_windows_wheel/action.yml @@ -29,7 +29,3 @@ runs: args: ${{ inputs.args }} docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'" working-directory: python - - uses: actions/upload-artifact@v4 - with: - name: windows-wheels - path: python\target\wheels diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index ca6e3219b..b15f2e6b2 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -8,6 +8,9 @@ on: # This should trigger a dry run (we skip the final publish step) paths: - .github/workflows/pypi-publish.yml + - .github/workflows/build_linux_wheel/action.yml + - .github/workflows/build_mac_wheel/action.yml + - .github/workflows/build_windows_wheel/action.yml - Cargo.toml # Change in dependency frequently breaks builds - Cargo.lock @@ -21,32 +24,21 @@ jobs: linux: name: Python ${{ matrix.config.platform }} manylinux${{ matrix.config.manylinux }} timeout-minutes: 60 - permissions: - id-token: write - contents: read strategy: matrix: config: - - platform: x86_64 - manylinux: "2_17" - extra_args: "" - runner: ubuntu-22.04 - platform: x86_64 manylinux: "2_28" extra_args: "--features fp16kernels" runner: ubuntu-22.04 - - platform: aarch64 - manylinux: "2_17" - extra_args: "" - # For successful fat LTO builds, we need a large runner to avoid OOM errors. - runner: ubuntu-2404-8x-arm64 + # For successful fat LTO builds, we need a large runner to avoid OOM errors. - platform: aarch64 manylinux: "2_28" extra_args: "--features fp16kernels" runner: ubuntu-2404-8x-arm64 runs-on: ${{ matrix.config.runner }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 lfs: true @@ -60,15 +52,14 @@ jobs: args: "--release --strip ${{ matrix.config.extra_args }}" arm-build: ${{ matrix.config.platform == 'aarch64' }} manylinux: ${{ matrix.config.manylinux }} - - uses: ./.github/workflows/upload_wheel + - uses: actions/upload-artifact@v7 if: startsWith(github.ref, 'refs/tags/python-v') with: - fury_token: ${{ secrets.FURY_TOKEN }} + name: wheels-linux-${{ matrix.config.platform }}-${{ matrix.config.manylinux }} + path: target/wheels/lancedb-*.whl + if-no-files-found: error mac: timeout-minutes: 90 - permissions: - id-token: write - contents: read runs-on: ${{ matrix.config.runner }} strategy: matrix: @@ -78,7 +69,7 @@ jobs: env: MACOSX_DEPLOYMENT_TARGET: 10.15 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 lfs: true @@ -90,18 +81,21 @@ jobs: with: python-minor-version: 10 args: "--release --strip --target ${{ matrix.config.target }} --features fp16kernels" - - uses: ./.github/workflows/upload_wheel + - uses: actions/upload-artifact@v7 if: startsWith(github.ref, 'refs/tags/python-v') with: - fury_token: ${{ secrets.FURY_TOKEN }} + name: wheels-mac-${{ matrix.config.target }} + path: target/wheels/lancedb-*.whl + if-no-files-found: error windows: - timeout-minutes: 60 - permissions: - id-token: write - contents: read + timeout-minutes: 90 runs-on: windows-latest + env: + # link.exe is single-threaded and the long pole on Windows builds. Use + # rustc's bundled lld-link instead. + CARGO_TARGET_X86_64_PC_WINDOWS_MSVC_LINKER: rust-lld steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 lfs: true @@ -113,18 +107,70 @@ jobs: with: python-minor-version: 10 args: "--release --strip" - vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }} - - uses: ./.github/workflows/upload_wheel + - uses: actions/upload-artifact@v7 if: startsWith(github.ref, 'refs/tags/python-v') with: - fury_token: ${{ secrets.FURY_TOKEN }} + name: wheels-windows + path: target/wheels/lancedb-*.whl + if-no-files-found: error + publish: + name: Publish wheels + if: startsWith(github.ref, 'refs/tags/python-v') + needs: [linux, mac, windows] + runs-on: ubuntu-latest + permissions: + id-token: write + contents: read + steps: + - uses: actions/checkout@v6 + - name: Download wheel artifacts + uses: actions/download-artifact@v8 + with: + pattern: wheels-* + path: target/wheels + merge-multiple: true + - name: List wheels + run: ls -la target/wheels + - name: Choose repo + id: choose_repo + run: | + if [[ ${{ github.ref }} == *beta* ]]; then + echo "repo=fury" >> $GITHUB_OUTPUT + else + echo "repo=pypi" >> $GITHUB_OUTPUT + fi + - name: Publish to Fury + if: steps.choose_repo.outputs.repo == 'fury' + env: + FURY_TOKEN: ${{ secrets.FURY_TOKEN }} + run: | + shopt -s nullglob + WHEELS=(target/wheels/lancedb-*.whl) + if [[ ${#WHEELS[@]} -eq 0 ]]; then + echo "No wheels found in target/wheels/" >&2 + exit 1 + fi + for WHEEL in "${WHEELS[@]}"; do + echo "Uploading $WHEEL to Fury" + curl -f -F package=@"$WHEEL" "https://$FURY_TOKEN@push.fury.io/lancedb/" + done + # NOTE: pypa/gh-action-pypi-publish must be invoked directly from a + # workflow file, not from inside a composite action. When called from a + # composite, `github.action_repository` is empty (actions/runner#2473) + # and the action falls back to `github.repository`, producing a bogus + # `docker://ghcr.io/:` image reference that GHA tries to pull. + - name: Publish to PyPI + if: steps.choose_repo.outputs.repo == 'pypi' + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: target/wheels/ gh-release: if: startsWith(github.ref, 'refs/tags/python-v') runs-on: ubuntu-latest permissions: contents: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 lfs: true @@ -187,13 +233,13 @@ jobs: report-failure: name: Report Workflow Failure runs-on: ubuntu-latest - needs: [linux, mac, windows] + needs: [linux, mac, windows, publish] permissions: contents: read issues: write if: always() && failure() && startsWith(github.ref, 'refs/tags/python-v') steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: ./.github/actions/create-failure-issue with: job-results: ${{ toJSON(needs) }} diff --git a/.github/workflows/upload_wheel/action.yml b/.github/workflows/upload_wheel/action.yml deleted file mode 100644 index 8bcdb7a88..000000000 --- a/.github/workflows/upload_wheel/action.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: upload-wheel - -description: "Upload wheels to Pypi" -inputs: - fury_token: - required: true - description: "release token for the fury repo" - -runs: - using: "composite" - steps: - - name: Choose repo - shell: bash - id: choose_repo - run: | - if [[ ${{ github.ref }} == *beta* ]]; then - echo "repo=fury" >> $GITHUB_OUTPUT - else - echo "repo=pypi" >> $GITHUB_OUTPUT - fi - - name: Publish to Fury - if: steps.choose_repo.outputs.repo == 'fury' - shell: bash - env: - FURY_TOKEN: ${{ inputs.fury_token }} - run: | - WHEEL=$(ls target/wheels/lancedb-*.whl 2> /dev/null | head -n 1) - echo "Uploading $WHEEL to Fury" - curl -f -F package=@$WHEEL https://$FURY_TOKEN@push.fury.io/lancedb/ - - name: Publish to PyPI - if: steps.choose_repo.outputs.repo == 'pypi' - uses: pypa/gh-action-pypi-publish@release/v1 - with: - packages-dir: target/wheels/ diff --git a/AGENTS.md b/AGENTS.md index 79e4de58b..14ec35441 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -37,10 +37,13 @@ Before committing changes, run formatting for every language you touched. At min and run targeted tests through `cd python && uv run ...`. * TypeScript changes: run the relevant `npm`/`pnpm` lint, format, build, and docs commands in `nodejs`. -Before creating a PR, make sure the PR title follows Conventional Commits, such as -`fix: support nested field paths in native index creation` or -`feat(python): add dataset multiprocessing support`. The semantic-release check uses the -PR title and body as the merge commit message, so a non-conventional PR title will fail CI. +Before creating a PR, the exact value passed to `gh pr create --title` must follow +Conventional Commits, such as `fix: support nested field paths in native index creation` +or `feat(python): add dataset multiprocessing support`. Do not use a plain natural +language summary like `Support nested field paths in native index creation` as the PR +title. The semantic-release check uses the PR title and body as the merge commit message, +so a non-conventional PR title will fail CI. After creating a PR, read the remote PR title +back and fix it immediately if it is not conventional. ## Coding tips diff --git a/Cargo.lock b/Cargo.lock index adbd99b13..171c8731b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -568,7 +568,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.4.0", + "http 1.4.1", "sha1 0.10.6", "time", "tokio", @@ -631,7 +631,7 @@ dependencies = [ "bytes-utils", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "percent-encoding", @@ -661,7 +661,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body-util", "regex-lite", "tracing", @@ -686,7 +686,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -710,7 +710,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -740,7 +740,7 @@ dependencies = [ "hex", "hmac 0.13.0", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "lru", "percent-encoding", @@ -769,7 +769,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -793,7 +793,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -818,7 +818,7 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -840,7 +840,7 @@ dependencies = [ "hex", "hmac 0.13.0", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "p256", "percent-encoding", "ring", @@ -873,7 +873,7 @@ dependencies = [ "bytes", "crc-fast", "hex", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "md-5 0.11.0", @@ -907,7 +907,7 @@ dependencies = [ "bytes-utils", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "percent-encoding", @@ -928,7 +928,7 @@ dependencies = [ "h2 0.3.27", "h2 0.4.14", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "hyper 0.14.32", "hyper 1.9.0", @@ -976,20 +976,21 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.11.1" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0504b1ab12debb5959e5165ee5fe97dd387e7aa7ea6a477bfd7635dfe769a4f5" +checksum = "b8e6f5caf6fea86f8c2206541ab5857cfcda9013426cdbe8fa0098b9e2d32182" dependencies = [ "aws-smithy-async", "aws-smithy-http", "aws-smithy-http-client", "aws-smithy-observability", "aws-smithy-runtime-api", + "aws-smithy-schema", "aws-smithy-types", "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -1001,16 +1002,16 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71a13df6ada0aafbf21a73bdfcdf9324cfa9df77d96b8446045be3cde61b42e" +checksum = "dc117c179ecf39a62a0a3f49f600e9ac26a7ad7dd172177999f83933af776c32" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api-macros", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "pin-project-lite", "tokio", "tracing", @@ -1029,17 +1030,28 @@ dependencies = [ ] [[package]] -name = "aws-smithy-types" -version = "1.4.7" +name = "aws-smithy-schema" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d73dbfbaa8e4bc57b9045137680b958d274823509a360abfd8e1d514d40c95c" +checksum = "7442cb268338f0eb8278140a107c046756aa01093d8ef5e99628d34ae09c94f5" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "http 1.4.1", +] + +[[package]] +name = "aws-smithy-types" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "056b66dbce2f81cc0c1e2b05bb402eb58f8a3530479d650efadd5bbae9a4050b" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -1087,7 +1099,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper 1.9.0", @@ -1120,7 +1132,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "mime", @@ -1399,6 +1411,12 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "bytecount" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" + [[package]] name = "bytemuck" version = "1.25.0" @@ -1522,9 +1540,9 @@ dependencies = [ [[package]] name = "cedarwood" -version = "0.4.6" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d910bedd62c24733263d0bed247460853c9d22e8956bd4cd964302095e04e90" +checksum = "c0524a528a6a0288df1863c3c20fe92c301875b4941e7b6c4b394ab08c5a4c55" dependencies = [ "smallvec", ] @@ -3284,8 +3302,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3675,7 +3693,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.4.0", + "http 1.4.1", "indexmap 2.14.0", "slab", "tokio", @@ -3781,7 +3799,7 @@ checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" dependencies = [ "dirs", "futures", - "http 1.4.0", + "http 1.4.1", "indicatif", "libc", "log", @@ -3804,7 +3822,7 @@ checksum = "430b33fa84f92796d4d263070b6c0d3ca219df7b9a0e1853ee431029b1612bcd" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "more-asserts", "serde", "thiserror 2.0.18", @@ -3858,9 +3876,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" dependencies = [ "bytes", "itoa", @@ -3884,7 +3902,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", ] [[package]] @@ -3895,7 +3913,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "pin-project-lite", ] @@ -3962,7 +3980,7 @@ dependencies = [ "futures-channel", "futures-core", "h2 0.4.14", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "httparse", "httpdate", @@ -3994,7 +4012,7 @@ version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ - "http 1.4.0", + "http 1.4.1", "hyper 1.9.0", "hyper-util", "rustls 0.23.40", @@ -4015,7 +4033,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "hyper 1.9.0", "ipnet", @@ -4077,6 +4095,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "icu_locale" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5a396343c7208121dc86e35623d3dfe19814a7613cfd14964994cdc9c9a2e26" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_locale_data", + "icu_provider", + "potential_utf", + "tinystr", + "zerovec", +] + [[package]] name = "icu_locale_core" version = "2.2.0" @@ -4085,11 +4118,18 @@ checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" dependencies = [ "displaydoc", "litemap", + "serde", "tinystr", "writeable", "zerovec", ] +[[package]] +name = "icu_locale_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fdcc9ac77c6d74ff5cf6e65ef3181d6af32003b16fce3a77fb451d2f695993" + [[package]] name = "icu_normalizer" version = "2.2.0" @@ -4138,6 +4178,8 @@ checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" dependencies = [ "displaydoc", "icu_locale_core", + "serde", + "stable_deref_trait", "writeable", "yoke", "zerofrom", @@ -4145,6 +4187,27 @@ dependencies = [ "zerovec", ] +[[package]] +name = "icu_segmenter" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0794db0b1a86193ac9c48768d0e6c52c54448e0870ad87907d456ee0dac964" +dependencies = [ + "icu_collections", + "icu_locale", + "icu_provider", + "icu_segmenter_data", + "potential_utf", + "utf8_iter", + "zerovec", +] + +[[package]] +name = "icu_segmenter_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4a2c462a4d927d512f5f882a033ddd62f33a05bb9f230d98f736ac3dc85938f" + [[package]] name = "id-arena" version = "2.3.0" @@ -4306,19 +4369,20 @@ checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" [[package]] name = "jieba-macros" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a29cfc5dcd898604c6f80363411fa6b6b08e27d1d253d6225b9cb6702ea02fc0" +checksum = "46adade69b634535a8f495cf87710ed893cff53e1dbc9dd750c2ab81c5defb82" dependencies = [ "phf_codegen 0.13.1", ] [[package]] name = "jieba-rs" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3245d6e9d1d5facbd6a23848d6b67e3439738ccbb4fa5a3d65da315ba1a910a2" +checksum = "11b53580aaa8ec8b713da271da434f8947409242c537a9ab3f7b76bdbb19e8a9" dependencies = [ + "bytecount", "cedarwood", "jieba-macros", "phf 0.13.1", @@ -4506,8 +4570,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arc-swap", "arrow", @@ -4525,6 +4589,7 @@ dependencies = [ "async_cell", "aws-credential-types", "aws-sdk-dynamodb", + "bitpacking", "byteorder", "bytes", "chrono", @@ -4551,9 +4616,11 @@ dependencies = [ "lance-io", "lance-linalg", "lance-namespace", + "lance-select", "lance-table", "lance-tokenizer", "log", + "moka", "object_store", "permutation", "pin-project", @@ -4577,8 +4644,8 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow-array", "arrow-buffer", @@ -4596,10 +4663,34 @@ dependencies = [ "rand 0.9.4", ] +[[package]] +name = "lance-arrow-scalar" +version = "58.0.0" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-row", + "arrow-schema", + "half", +] + +[[package]] +name = "lance-arrow-stats" +version = "58.0.0" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" +dependencies = [ + "arrow-array", + "arrow-schema", + "lance-arrow-scalar", +] + [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrayref", "paste", @@ -4608,8 +4699,8 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow-array", "arrow-buffer", @@ -4644,8 +4735,8 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow", "arrow-array", @@ -4675,8 +4766,8 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow", "arrow-array", @@ -4694,8 +4785,8 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow-arith", "arrow-array", @@ -4730,8 +4821,8 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow-arith", "arrow-array", @@ -4762,8 +4853,8 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arc-swap", "arrow", @@ -4793,6 +4884,7 @@ dependencies = [ "jieba-rs", "jsonb", "lance-arrow", + "lance-arrow-stats", "lance-core", "lance-datafusion", "lance-datagen", @@ -4800,6 +4892,7 @@ dependencies = [ "lance-file", "lance-io", "lance-linalg", + "lance-select", "lance-table", "lance-tokenizer", "libm", @@ -4827,8 +4920,8 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow", "arrow-arith", @@ -4847,7 +4940,7 @@ dependencies = [ "chrono", "deepsize", "futures", - "http 1.4.0", + "http 1.4.1", "io-uring", "lance-arrow", "lance-core", @@ -4870,8 +4963,8 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow-array", "arrow-buffer", @@ -4887,8 +4980,8 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow", "async-trait", @@ -4900,8 +4993,8 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow", "arrow-ipc", @@ -4936,9 +5029,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65e31bdaa13e01dab6e7cf566da31df243c34a542f0d915d3601ec0e01e61d2" +checksum = "6369eee4682fb11edf538388b43c61ce288b8302fe89bb40944d7daa7faaae99" dependencies = [ "reqwest 0.12.28", "serde", @@ -4948,10 +5041,25 @@ dependencies = [ "url", ] +[[package]] +name = "lance-select" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" +dependencies = [ + "arrow-array", + "arrow-buffer", + "byteorder", + "bytes", + "deepsize", + "itertools 0.13.0", + "lance-core", + "roaring", +] + [[package]] name = "lance-table" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow", "arrow-array", @@ -4970,6 +5078,7 @@ dependencies = [ "lance-core", "lance-file", "lance-io", + "lance-select", "log", "object_store", "prost", @@ -4990,8 +5099,8 @@ dependencies = [ [[package]] name = "lance-testing" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ "arrow-array", "arrow-schema", @@ -5002,9 +5111,10 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.13" -source = "git+https://github.com/lance-format/lance.git?tag=v7.0.0-beta.13#929166e3ff51ed61b1fa42de2c63feaf51967ea1" +version = "7.2.0-beta.1" +source = "git+https://github.com/lance-format/lance.git?tag=v7.2.0-beta.1#b9995aba6115e8e4bc43179a45cbd0f9a170f305" dependencies = [ + "icu_segmenter", "jieba-rs", "lindera", "rust-stemmers", @@ -5014,7 +5124,7 @@ dependencies = [ [[package]] name = "lancedb" -version = "0.30.0-beta.0" +version = "0.30.0-beta.1" dependencies = [ "ahash", "anyhow", @@ -5051,7 +5161,7 @@ dependencies = [ "futures", "half", "hf-hub", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "lance", "lance-arrow", @@ -5084,6 +5194,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "serial_test", "snafu 0.8.9", "tempfile", "test-log", @@ -5096,7 +5207,7 @@ dependencies = [ [[package]] name = "lancedb-nodejs" -version = "0.30.0-beta.0" +version = "0.30.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -5119,7 +5230,7 @@ dependencies = [ [[package]] name = "lancedb-python" -version = "0.33.0-beta.0" +version = "0.33.0-beta.1" dependencies = [ "arrow", "async-trait", @@ -5330,9 +5441,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" [[package]] name = "loom" @@ -5902,7 +6013,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body-util", "httparse", "humantime", @@ -6015,7 +6126,7 @@ dependencies = [ "base64 0.22.1", "bytes", "futures", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "jiff", "log", @@ -6040,7 +6151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "048b1b29c503263bdd80a9afe46a68cd02ea9bd361185b1feab4b151078998e9" dependencies = [ "futures", - "http 1.4.0", + "http 1.4.1", "mea", "opendal-core", ] @@ -6084,7 +6195,7 @@ checksum = "7452bf3ec61cfd81ac9ad9ada17825931e9e371d44a045c6bfab9596c0a2ac3b" dependencies = [ "base64 0.22.1", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -6104,7 +6215,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f9884c2d8cf8ba2bb077d79c877dac5863ba3bab9e2c9c1e41a2e0491404772" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -6122,7 +6233,7 @@ version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffb0e45d6c8dcf66ce2da20e241bcb80e6e540e109a4ff20f318f6c9b4c54e0c" dependencies = [ - "http 1.4.0", + "http 1.4.1", "opendal-core", ] @@ -6134,7 +6245,7 @@ checksum = "70a49477a10163431896d106136117f5670717f9c9e49cf6f710528800c6633a" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", @@ -6155,7 +6266,7 @@ checksum = "7b2ab7a2a8a11dfe257ef4db5c0de798acbcd0d6429c37382dad2154bc06a388" dependencies = [ "bytes", "hf-xet", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", @@ -6171,7 +6282,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29c8a917829ad06d21b639558532cb0101fe49b040d946d673a73018683fac05" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "quick-xml 0.38.4", @@ -6190,7 +6301,7 @@ dependencies = [ "base64 0.22.1", "bytes", "crc32c", - "http 1.4.0", + "http 1.4.1", "log", "md-5 0.10.6", "opendal-core", @@ -6934,6 +7045,8 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ + "serde_core", + "writeable", "zerovec", ] @@ -7597,7 +7710,7 @@ checksum = "57ac2757f3140aa2e213b554148ae0b52733e624fc6723f0cc6bb3d440176c95" dependencies = [ "anyhow", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "reqsign-core", @@ -7615,7 +7728,7 @@ dependencies = [ "anyhow", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "quick-xml 0.39.4", @@ -7637,7 +7750,7 @@ dependencies = [ "base64 0.22.1", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "pem", @@ -7662,7 +7775,7 @@ dependencies = [ "futures", "hex", "hmac 0.12.1", - "http 1.4.0", + "http 1.4.1", "jiff", "log", "percent-encoding", @@ -7689,7 +7802,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35cc609b49c69e76ecaceb775a03f792d1ed3e7755ab3548d4534fd801e3242e" dependencies = [ "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "percent-encoding", @@ -7714,7 +7827,7 @@ dependencies = [ "futures-core", "futures-util", "h2 0.4.14", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper 1.9.0", @@ -7758,7 +7871,7 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper 1.9.0", @@ -7812,7 +7925,7 @@ checksum = "199dda04a536b532d0cc04d7979e39b1c763ea749bf91507017069c00b96056f" dependencies = [ "anyhow", "async-trait", - "http 1.4.0", + "http 1.4.1", "reqwest 0.13.3", "thiserror 2.0.18", "tower-service", @@ -8128,6 +8241,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.29" @@ -8194,6 +8316,12 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + [[package]] name = "sec1" version = "0.3.0" @@ -8285,9 +8413,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -8384,6 +8512,32 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "serial_test" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f" +dependencies = [ + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "sha1" version = "0.10.6" @@ -8406,6 +8560,12 @@ dependencies = [ "digest 0.11.3", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" @@ -9125,6 +9285,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" dependencies = [ "displaydoc", + "serde_core", "zerovec", ] @@ -9313,7 +9474,7 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "bitflags 2.11.1", "bytes", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -9333,7 +9494,7 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -9622,13 +9783,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "d258b83ceec21034727ecee8c382cfa6c3e133699b0742c64571814fb420c9f7" dependencies = [ "getrandom 0.4.2", "js-sys", "serde_core", + "sha1_smol", "wasm-bindgen", ] @@ -10353,7 +10515,7 @@ dependencies = [ "clap", "crc32fast", "futures", - "http 1.4.0", + "http 1.4.1", "hyper 1.9.0", "lazy_static", "more-asserts", @@ -10427,7 +10589,7 @@ dependencies = [ "chrono", "clap", "gearhash", - "http 1.4.0", + "http 1.4.1", "itertools 0.14.0", "lazy_static", "more-asserts", @@ -10592,6 +10754,7 @@ dependencies = [ "displaydoc", "yoke", "zerofrom", + "zerovec", ] [[package]] @@ -10600,6 +10763,7 @@ version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" dependencies = [ + "serde", "yoke", "zerofrom", "zerovec-derive", diff --git a/Cargo.toml b/Cargo.toml index 1c87a4f25..7ff946895 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,20 +13,20 @@ categories = ["database-implementations"] rust-version = "1.91.0" [workspace.dependencies] -lance = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-core = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-datagen = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-file = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-io = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-index = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-linalg = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-namespace = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-namespace-impls = { "version" = "=7.0.0-beta.13", default-features = false, "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-table = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-testing = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-datafusion = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-encoding = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } -lance-arrow = { "version" = "=7.0.0-beta.13", "tag" = "v7.0.0-beta.13", "git" = "https://github.com/lance-format/lance.git" } +lance = { "version" = "=7.2.0-beta.1", default-features = false, "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-core = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-datagen = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-file = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-io = { "version" = "=7.2.0-beta.1", default-features = false, "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-index = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-linalg = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-namespace = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-namespace-impls = { "version" = "=7.2.0-beta.1", default-features = false, "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-table = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-testing = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-datafusion = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-encoding = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } +lance-arrow = { "version" = "=7.2.0-beta.1", "tag" = "v7.2.0-beta.1", "git" = "https://github.com/lance-format/lance.git" } ahash = "0.8" # Note that this one does not include pyarrow arrow = { version = "58.0.0", optional = false } diff --git a/ci/check_lance_release.py b/ci/check_lance_release.py index e906dd489..47f1cdbde 100755 --- a/ci/check_lance_release.py +++ b/ci/check_lance_release.py @@ -112,25 +112,25 @@ def fetch_remote_tags() -> List[TagInfo]: "api", "-X", "GET", - f"repos/{LANCE_REPO}/git/refs/tags", - "--paginate", + f"repos/{LANCE_REPO}/releases", "--jq", - ".[].ref", + ".[].tag_name", + "-F", + "per_page=20", ] ) tags: List[TagInfo] = [] for line in output.splitlines(): - ref = line.strip() - if not ref.startswith("refs/tags/v"): + tag = line.strip() + if not tag.startswith("v"): continue - tag = ref.split("refs/tags/")[-1] version = tag.lstrip("v") try: tags.append(TagInfo(tag=tag, version=version, semver=parse_semver(version))) except ValueError: continue if not tags: - raise RuntimeError("No Lance tags could be parsed from GitHub API output") + raise RuntimeError("No Lance releases could be parsed from GitHub API output") return tags diff --git a/docs/src/java/java.md b/docs/src/java/java.md index f263a7dee..14581bee0 100644 --- a/docs/src/java/java.md +++ b/docs/src/java/java.md @@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`: com.lancedb lancedb-core - 0.30.0-beta.0 + 0.30.0-beta.1 ``` diff --git a/docs/src/js/classes/MergeInsertBuilder.md b/docs/src/js/classes/MergeInsertBuilder.md index ae601c9e2..ac0493bad 100644 --- a/docs/src/js/classes/MergeInsertBuilder.md +++ b/docs/src/js/classes/MergeInsertBuilder.md @@ -76,6 +76,57 @@ the query optimizer chooses a suboptimal path. *** +### useLsmWrite() + +```ts +useLsmWrite(useLsmWrite): MergeInsertBuilder +``` + +Controls whether the merge uses the MemWAL LSM write path. + +By default (unset), a `mergeInsert` on a table with an LSM write spec is +routed through Lance's MemWAL shard writer, and a table without one uses +the standard path. Pass `false` to force the standard path even when a +spec is set. Pass `true` to require a spec — `mergeInsert` rejects if none +is installed. + +#### Parameters + +* **useLsmWrite**: `boolean` + Whether to use the LSM write path. + +#### Returns + +[`MergeInsertBuilder`](MergeInsertBuilder.md) + +*** + +### validateSingleShard() + +```ts +validateSingleShard(validateSingleShard): MergeInsertBuilder +``` + +Controls how an LSM merge checks that its input targets a single shard. + +When a table has an LSM write spec, every row in a `mergeInsert` call must +route to the same shard. When `true` (the default), every row is inspected +to verify this. When `false`, only the first row is inspected and the +shard it routes to is used for the whole input — a faster path for callers +that have already pre-sharded their input. Has no effect on tables without +an LSM write spec. + +#### Parameters + +* **validateSingleShard**: `boolean` + Whether to check every row routes to one shard. Defaults to `true`. + +#### Returns + +[`MergeInsertBuilder`](MergeInsertBuilder.md) + +*** + ### whenMatchedUpdateAll() ```ts diff --git a/docs/src/js/classes/Table.md b/docs/src/js/classes/Table.md index 45fa13362..62b962daf 100644 --- a/docs/src/js/classes/Table.md +++ b/docs/src/js/classes/Table.md @@ -187,6 +187,25 @@ Any attempt to use the table after it is closed will result in an error. *** +### closeLsmWriters() + +```ts +abstract closeLsmWriters(): Promise +``` + +Drain and close any cached MemWAL shard writers held for this table. + +When an [LsmWriteSpec](../interfaces/LsmWriteSpec.md) is installed, `mergeInsert` opens MemWAL +shard writers and caches them for reuse across calls. This closes them, +flushing pending data; writers reopen lazily on the next `mergeInsert`. +It is a no-op when no writers are cached. + +#### Returns + +`Promise`<`void`> + +*** + ### countRows() ```ts diff --git a/docs/src/js/interfaces/ConnectionOptions.md b/docs/src/js/interfaces/ConnectionOptions.md index 1ad0e127a..de2083a9b 100644 --- a/docs/src/js/interfaces/ConnectionOptions.md +++ b/docs/src/js/interfaces/ConnectionOptions.md @@ -70,16 +70,20 @@ client used by manifest-enabled native connections. optional readConsistencyInterval: number; ``` -(For LanceDB OSS only): The interval, in seconds, at which to check for -updates to the table from other processes. If None, then consistency is not -checked. For performance reasons, this is the default. For strong -consistency, set this to zero seconds. Then every read will check for -updates from other processes. As a compromise, you can set this to a -non-zero value for eventual consistency. If more than that interval -has passed since the last check, then the table will be checked for updates. -Note: this consistency only applies to read operations. Write operations are +The interval, in seconds, at which to check for updates to the table +from other processes. If None, then consistency is not checked. For +performance reasons, this is the default. For strong consistency, set +this to zero seconds. Then every read will check for updates from other +processes. As a compromise, you can set this to a non-zero value for +eventual consistency. If more than that interval has passed since the +last check, then the table will be checked for updates. Note: this +consistency only applies to read operations. Write operations are always consistent. +Stronger consistency is not free. The smaller the interval, the more +often each read pays the cost of checking for updates against object +storage, raising per-read latency and cost. + *** ### region? diff --git a/docs/src/js/interfaces/LsmWriteSpec.md b/docs/src/js/interfaces/LsmWriteSpec.md index 017e819dc..8a588df6a 100644 --- a/docs/src/js/interfaces/LsmWriteSpec.md +++ b/docs/src/js/interfaces/LsmWriteSpec.md @@ -11,7 +11,10 @@ Specification selecting Lance's MemWAL LSM-style write path for `specType` is `"bucket"`, `"identity"`, or `"unsharded"`. For `"bucket"`, `column` and `numBuckets` are required; for `"identity"`, `column` is -required. +required and must be a deterministic function of the unenforced primary +key (every row with a given primary key must always produce the same +`column` value, or upserts of that key can land in different shards and a +stale version can win). ## Properties diff --git a/docs/src/js/interfaces/MergeResult.md b/docs/src/js/interfaces/MergeResult.md index d59049cb8..6114fabfa 100644 --- a/docs/src/js/interfaces/MergeResult.md +++ b/docs/src/js/interfaces/MergeResult.md @@ -32,6 +32,14 @@ numInsertedRows: number; *** +### numRows + +```ts +numRows: number; +``` + +*** + ### numUpdatedRows ```ts diff --git a/java/lancedb-core/pom.xml b/java/lancedb-core/pom.xml index efa414320..f46ae0c94 100644 --- a/java/lancedb-core/pom.xml +++ b/java/lancedb-core/pom.xml @@ -8,7 +8,7 @@ com.lancedb lancedb-parent - 0.30.0-beta.0 + 0.30.0-beta.1 ../pom.xml diff --git a/java/pom.xml b/java/pom.xml index 800183127..7b23cd52a 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -6,7 +6,7 @@ com.lancedb lancedb-parent - 0.30.0-beta.0 + 0.30.0-beta.1 pom ${project.artifactId} LanceDB Java SDK Parent POM @@ -28,7 +28,7 @@ UTF-8 15.0.0 - 7.0.0-beta.13 + 7.2.0-beta.1 false 2.30.0 1.7 diff --git a/nodejs/Cargo.toml b/nodejs/Cargo.toml index 29a255e9a..ce874d019 100644 --- a/nodejs/Cargo.toml +++ b/nodejs/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "lancedb-nodejs" edition.workspace = true -version = "0.30.0-beta.0" +version = "0.30.0-beta.1" publish = false license.workspace = true description.workspace = true diff --git a/nodejs/__test__/connection.test.ts b/nodejs/__test__/connection.test.ts index d195a0ca4..68180471a 100644 --- a/nodejs/__test__/connection.test.ts +++ b/nodejs/__test__/connection.test.ts @@ -171,18 +171,22 @@ describe("given a connection", () => { let manifestDir = tmpDir.name + "/test_manifest_paths_v2_empty.lance/_versions"; - readdirSync(manifestDir).forEach((file) => { - expect(file).toMatch(/^\d{20}\.manifest$/); - }); + readdirSync(manifestDir) + .filter((f) => f.endsWith(".manifest")) + .forEach((file) => { + expect(file).toMatch(/^\d{20}\.manifest$/); + }); table = (await db.createTable("test_manifest_paths_v2", [{ id: 1 }], { enableV2ManifestPaths: true, })) as LocalTable; expect(await table.usesV2ManifestPaths()).toBe(true); manifestDir = tmpDir.name + "/test_manifest_paths_v2.lance/_versions"; - readdirSync(manifestDir).forEach((file) => { - expect(file).toMatch(/^\d{20}\.manifest$/); - }); + readdirSync(manifestDir) + .filter((f) => f.endsWith(".manifest")) + .forEach((file) => { + expect(file).toMatch(/^\d{20}\.manifest$/); + }); }); it("should be able to migrate tables to the V2 manifest paths", async () => { @@ -199,16 +203,20 @@ describe("given a connection", () => { const manifestDir = tmpDir.name + "/test_manifest_path_migration.lance/_versions"; - readdirSync(manifestDir).forEach((file) => { - expect(file).toMatch(/^\d\.manifest$/); - }); + readdirSync(manifestDir) + .filter((f) => f.endsWith(".manifest")) + .forEach((file) => { + expect(file).toMatch(/^\d\.manifest$/); + }); await table.migrateManifestPathsV2(); expect(await table.usesV2ManifestPaths()).toBe(true); - readdirSync(manifestDir).forEach((file) => { - expect(file).toMatch(/^\d{20}\.manifest$/); - }); + readdirSync(manifestDir) + .filter((f) => f.endsWith(".manifest")) + .forEach((file) => { + expect(file).toMatch(/^\d{20}\.manifest$/); + }); }); }); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 7d43ca351..3be56d3c7 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -2625,3 +2625,97 @@ describe("setLsmWriteSpec / unsetLsmWriteSpec", () => { ).rejects.toThrow(); }); }); + +describe("LSM merge insert", () => { + let tmpDir: tmp.DirResult; + + beforeEach(() => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + }); + afterEach(() => tmpDir.removeCallback()); + + async function bucketTable(conn: Connection): Promise { + // The primary key column must be non-nullable. + const table = await conn.createEmptyTable( + "t", + new arrow.Schema([ + new arrow.Field("id", new arrow.Utf8(), false), + new arrow.Field("value", new arrow.Float64(), true), + ]), + ); + await table.add([ + { id: "a", value: 1 }, + { id: "b", value: 2 }, + ]); + await table.setUnenforcedPrimaryKey("id"); + // numBuckets = 1: every row routes to the single bucket. + await table.setLsmWriteSpec({ + specType: "bucket", + column: "id", + numBuckets: 1, + }); + return table; + } + + it("routes merge_insert through the shard writer", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + const res = await table + .mergeInsert("id") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .execute([ + { id: "c", value: 3 }, + { id: "d", value: 4 }, + ]); + // LSM path: rows go to the MemWAL, so only numRows is populated. + expect(res.numRows).toBe(2); + expect(res.version).toBe(0); + expect(res.numInsertedRows).toBe(0); + + await table.closeLsmWriters(); + }); + + it("falls back to the standard path with useLsmWrite(false)", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + const res = await table + .mergeInsert("id") + .whenNotMatchedInsertAll() + .useLsmWrite(false) + .execute([ + { id: "b", value: 9 }, + { id: "e", value: 5 }, + ]); + // Standard path commits: id="e" inserted ("b" already exists). + expect(res.numInsertedRows).toBe(1); + expect(await table.countRows()).toBe(3); + }); + + it("supports validateSingleShard(false)", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + const res = await table + .mergeInsert("id") + .whenMatchedUpdateAll() + .whenNotMatchedInsertAll() + .validateSingleShard(false) + .execute([{ id: "f", value: 6 }]); + expect(res.numRows).toBe(1); + }); + + it("rejects a non-upsert merge under an LSM spec", async () => { + const conn = await connect(tmpDir.name); + const table = await bucketTable(conn); + + await expect( + table + .mergeInsert("id") + .whenNotMatchedInsertAll() + .execute([{ id: "g", value: 7 }]), + ).rejects.toThrow(); + }); +}); diff --git a/nodejs/lancedb/merge.ts b/nodejs/lancedb/merge.ts index dc9144fdf..08321427f 100644 --- a/nodejs/lancedb/merge.ts +++ b/nodejs/lancedb/merge.ts @@ -87,6 +87,41 @@ export class MergeInsertBuilder { this.#schema, ); } + /** + * Controls whether the merge uses the MemWAL LSM write path. + * + * By default (unset), a `mergeInsert` on a table with an LSM write spec is + * routed through Lance's MemWAL shard writer, and a table without one uses + * the standard path. Pass `false` to force the standard path even when a + * spec is set. Pass `true` to require a spec — `mergeInsert` rejects if none + * is installed. + * + * @param useLsmWrite - Whether to use the LSM write path. + */ + useLsmWrite(useLsmWrite: boolean): MergeInsertBuilder { + return new MergeInsertBuilder( + this.#native.useLsmWrite(useLsmWrite), + this.#schema, + ); + } + /** + * Controls how an LSM merge checks that its input targets a single shard. + * + * When a table has an LSM write spec, every row in a `mergeInsert` call must + * route to the same shard. When `true` (the default), every row is inspected + * to verify this. When `false`, only the first row is inspected and the + * shard it routes to is used for the whole input — a faster path for callers + * that have already pre-sharded their input. Has no effect on tables without + * an LSM write spec. + * + * @param validateSingleShard - Whether to check every row routes to one shard. Defaults to `true`. + */ + validateSingleShard(validateSingleShard: boolean): MergeInsertBuilder { + return new MergeInsertBuilder( + this.#native.validateSingleShard(validateSingleShard), + this.#schema, + ); + } /** * Executes the merge insert operation * diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index fe495392a..ae2e86995 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -161,7 +161,10 @@ export interface Version { * * `specType` is `"bucket"`, `"identity"`, or `"unsharded"`. For `"bucket"`, * `column` and `numBuckets` are required; for `"identity"`, `column` is - * required. + * required and must be a deterministic function of the unenforced primary + * key (every row with a given primary key must always produce the same + * `column` value, or upserts of that key can land in different shards and a + * stale version can win). */ export interface LsmWriteSpec { /** One of `"bucket"`, `"identity"`, or `"unsharded"`. */ @@ -567,6 +570,16 @@ export abstract class Table { * @returns {Promise} */ abstract unsetLsmWriteSpec(): Promise; + /** + * Drain and close any cached MemWAL shard writers held for this table. + * + * When an {@link LsmWriteSpec} is installed, `mergeInsert` opens MemWAL + * shard writers and caches them for reuse across calls. This closes them, + * flushing pending data; writers reopen lazily on the next `mergeInsert`. + * It is a no-op when no writers are cached. + * @returns {Promise} + */ + abstract closeLsmWriters(): Promise; /** Retrieve the version of the table */ abstract version(): Promise; @@ -1041,6 +1054,10 @@ export class LocalTable extends Table { return await this.inner.unsetLsmWriteSpec(); } + async closeLsmWriters(): Promise { + return await this.inner.closeLsmWriters(); + } + async version(): Promise { return await this.inner.version(); } diff --git a/nodejs/npm/darwin-arm64/package.json b/nodejs/npm/darwin-arm64/package.json index 0f7382778..f088fc28c 100644 --- a/nodejs/npm/darwin-arm64/package.json +++ b/nodejs/npm/darwin-arm64/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-darwin-arm64", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "os": ["darwin"], "cpu": ["arm64"], "main": "lancedb.darwin-arm64.node", diff --git a/nodejs/npm/linux-arm64-gnu/package.json b/nodejs/npm/linux-arm64-gnu/package.json index 32704daf4..0179aa5c1 100644 --- a/nodejs/npm/linux-arm64-gnu/package.json +++ b/nodejs/npm/linux-arm64-gnu/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-arm64-gnu", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "os": ["linux"], "cpu": ["arm64"], "main": "lancedb.linux-arm64-gnu.node", diff --git a/nodejs/npm/linux-arm64-musl/package.json b/nodejs/npm/linux-arm64-musl/package.json index 51bdb4c4d..ac7a1eb5f 100644 --- a/nodejs/npm/linux-arm64-musl/package.json +++ b/nodejs/npm/linux-arm64-musl/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-arm64-musl", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "os": ["linux"], "cpu": ["arm64"], "main": "lancedb.linux-arm64-musl.node", diff --git a/nodejs/npm/linux-x64-gnu/package.json b/nodejs/npm/linux-x64-gnu/package.json index 49521f1f0..9fbc99f9c 100644 --- a/nodejs/npm/linux-x64-gnu/package.json +++ b/nodejs/npm/linux-x64-gnu/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-x64-gnu", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "os": ["linux"], "cpu": ["x64"], "main": "lancedb.linux-x64-gnu.node", diff --git a/nodejs/npm/linux-x64-musl/package.json b/nodejs/npm/linux-x64-musl/package.json index 1948a952c..bdb789c41 100644 --- a/nodejs/npm/linux-x64-musl/package.json +++ b/nodejs/npm/linux-x64-musl/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-x64-musl", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "os": ["linux"], "cpu": ["x64"], "main": "lancedb.linux-x64-musl.node", diff --git a/nodejs/npm/win32-arm64-msvc/package.json b/nodejs/npm/win32-arm64-msvc/package.json index 823778741..0d45ccbae 100644 --- a/nodejs/npm/win32-arm64-msvc/package.json +++ b/nodejs/npm/win32-arm64-msvc/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-win32-arm64-msvc", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "os": [ "win32" ], diff --git a/nodejs/npm/win32-x64-msvc/package.json b/nodejs/npm/win32-x64-msvc/package.json index 45afb9172..8f08bdffe 100644 --- a/nodejs/npm/win32-x64-msvc/package.json +++ b/nodejs/npm/win32-x64-msvc/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-win32-x64-msvc", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "os": ["win32"], "cpu": ["x64"], "main": "lancedb.win32-x64-msvc.node", diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index b8f63d9ab..a309e4347 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -1,12 +1,12 @@ { "name": "@lancedb/lancedb", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@lancedb/lancedb", - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "cpu": [ "x64", "arm64" diff --git a/nodejs/package.json b/nodejs/package.json index 980464129..a74c14287 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -11,7 +11,7 @@ "ann" ], "private": false, - "version": "0.30.0-beta.0", + "version": "0.30.0-beta.1", "main": "dist/index.js", "exports": { ".": "./dist/index.js", diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index f241fb81f..53c630c93 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -24,15 +24,19 @@ mod util; #[napi(object)] #[derive(Debug)] pub struct ConnectionOptions { - /// (For LanceDB OSS only): The interval, in seconds, at which to check for - /// updates to the table from other processes. If None, then consistency is not - /// checked. For performance reasons, this is the default. For strong - /// consistency, set this to zero seconds. Then every read will check for - /// updates from other processes. As a compromise, you can set this to a - /// non-zero value for eventual consistency. If more than that interval - /// has passed since the last check, then the table will be checked for updates. - /// Note: this consistency only applies to read operations. Write operations are + /// The interval, in seconds, at which to check for updates to the table + /// from other processes. If None, then consistency is not checked. For + /// performance reasons, this is the default. For strong consistency, set + /// this to zero seconds. Then every read will check for updates from other + /// processes. As a compromise, you can set this to a non-zero value for + /// eventual consistency. If more than that interval has passed since the + /// last check, then the table will be checked for updates. Note: this + /// consistency only applies to read operations. Write operations are /// always consistent. + /// + /// Stronger consistency is not free. The smaller the interval, the more + /// often each read pays the cost of checking for updates against object + /// storage, raising per-read latency and cost. pub read_consistency_interval: Option, /// (For LanceDB OSS only): configuration for object storage. /// diff --git a/nodejs/src/merge.rs b/nodejs/src/merge.rs index 98d637fb3..5ba9846bc 100644 --- a/nodejs/src/merge.rs +++ b/nodejs/src/merge.rs @@ -50,6 +50,20 @@ impl NativeMergeInsertBuilder { this } + #[napi] + pub fn use_lsm_write(&self, use_lsm_write: bool) -> Self { + let mut this = self.clone(); + this.inner.use_lsm_write(use_lsm_write); + this + } + + #[napi] + pub fn validate_single_shard(&self, validate_single_shard: bool) -> Self { + let mut this = self.clone(); + this.inner.validate_single_shard(validate_single_shard); + this + } + #[napi(catch_unwind)] pub async fn execute(&self, buf: Buffer) -> napi::Result { let data = ipc_file_to_batches(buf.to_vec()) diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 4c5424bc9..16cde35d8 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -391,6 +391,11 @@ impl Table { .default_error() } + #[napi(catch_unwind)] + pub async fn close_lsm_writers(&self) -> napi::Result<()> { + self.inner_ref()?.close_lsm_writers().await.default_error() + } + #[napi(catch_unwind)] pub async fn version(&self) -> napi::Result { self.inner_ref()? @@ -940,6 +945,7 @@ pub struct MergeResult { pub num_updated_rows: i64, pub num_deleted_rows: i64, pub num_attempts: i64, + pub num_rows: i64, } impl From for MergeResult { @@ -950,6 +956,7 @@ impl From for MergeResult { num_updated_rows: value.num_updated_rows as i64, num_deleted_rows: value.num_deleted_rows as i64, num_attempts: value.num_attempts as i64, + num_rows: value.num_rows as i64, } } } diff --git a/python/.bumpversion.toml b/python/.bumpversion.toml index 652ebecd4..8982a7eb9 100644 --- a/python/.bumpversion.toml +++ b/python/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "0.33.0-beta.0" +current_version = "0.33.0-beta.1" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/python/Cargo.toml b/python/Cargo.toml index af944bc97..272c5a5c8 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lancedb-python" -version = "0.33.0-beta.0" +version = "0.33.0-beta.1" publish = false edition.workspace = true description = "Python bindings for LanceDB" diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index be7a2b0fd..e748e1402 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -94,7 +94,6 @@ def connect( host_override: str, optional The override url for LanceDB Cloud. read_consistency_interval: timedelta, default None - (For LanceDB OSS only) The interval at which to check for updates to the table from other processes. If None, then consistency is not checked. For performance reasons, this is the default. For strong consistency, set this to @@ -104,6 +103,10 @@ def connect( the last check, then the table will be checked for updates. Note: this consistency only applies to read operations. Write operations are always consistent. + + Stronger consistency is not free. The smaller the interval, the more + often each read pays the cost of checking for updates against object + storage, raising per-read latency and cost. client_config: ClientConfig or dict, optional Configuration options for the LanceDB Cloud HTTP client. If a dict, then the keys are the attributes of the ClientConfig class. If None, then the @@ -147,6 +150,13 @@ def connect( >>> db = lancedb.connect("s3://my-bucket/lancedb", ... storage_options={"aws_access_key_id": "***"}) + For tests and temporary data, use an in-memory database: + + >>> db = lancedb.connect("memory://") + + In-memory databases are not persisted. Tables are dropped when the last + connection or table handle referencing them is closed. + Connect to LanceDB cloud: >>> db = lancedb.connect("db://my_database", api_key="ldb_...", @@ -210,6 +220,7 @@ def connect( request_thread_pool=request_thread_pool, client_config=client_config, storage_options=storage_options, + read_consistency_interval=read_consistency_interval, **kwargs, ) _check_s3_bucket_with_dots(str(uri), storage_options) @@ -345,7 +356,6 @@ async def connect_async( host_override: str, optional The override url for LanceDB Cloud. read_consistency_interval: timedelta, default None - (For LanceDB OSS only) The interval at which to check for updates to the table from other processes. If None, then consistency is not checked. For performance reasons, this is the default. For strong consistency, set this to @@ -355,6 +365,10 @@ async def connect_async( the last check, then the table will be checked for updates. Note: this consistency only applies to read operations. Write operations are always consistent. + + Stronger consistency is not free. The smaller the interval, the more + often each read pays the cost of checking for updates against object + storage, raising per-read latency and cost. client_config: ClientConfig or dict, optional Configuration options for the LanceDB Cloud HTTP client. If a dict, then the keys are the attributes of the ClientConfig class. If None, then the @@ -387,6 +401,8 @@ async def connect_async( ... db = await lancedb.connect_async("s3://my-bucket/lancedb", ... storage_options={ ... "aws_access_key_id": "***"}) + ... # For tests and temporary data, use an in-memory database + ... db = await lancedb.connect_async("memory://") ... # Connect to LanceDB cloud ... db = await lancedb.connect_async("db://my_database", api_key="ldb_...", ... client_config={ diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index db28e0fc8..0148f6575 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -220,6 +220,7 @@ class Table: async def set_unenforced_primary_key(self, columns: List[str]) -> None: ... async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ... async def unset_lsm_write_spec(self) -> None: ... + async def close_lsm_writers(self) -> None: ... @property def tags(self) -> Tags: ... def query(self) -> Query: ... @@ -420,6 +421,7 @@ class MergeResult: num_inserted_rows: int num_deleted_rows: int num_attempts: int + num_rows: int class LsmWriteSpec: """Specification selecting Lance's MemWAL LSM-style write path for diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py index f3f4d6a6e..67656d8a3 100644 --- a/python/python/lancedb/index.py +++ b/python/python/lancedb/index.py @@ -281,6 +281,9 @@ class HnswPq: m: int = 20 ef_construction: int = 300 target_partition_size: Optional[int] = None + # Name of the accelerator (e.g. "cuda") to use for IVF training. When set, + # create_index() dispatches to pylance to build the index on the accelerator. + accelerator: Optional[str] = None @dataclass @@ -386,6 +389,9 @@ class HnswSq: m: int = 20 ef_construction: int = 300 target_partition_size: Optional[int] = None + # Name of the accelerator (e.g. "cuda") to use for IVF training. When set, + # create_index() dispatches to pylance to build the index on the accelerator. + accelerator: Optional[str] = None @dataclass @@ -579,6 +585,9 @@ class IvfFlat: max_iterations: int = 50 sample_rate: int = 256 target_partition_size: Optional[int] = None + # Name of the accelerator (e.g. "cuda") to use for IVF training. When set, + # create_index() dispatches to pylance to build the index on the accelerator. + accelerator: Optional[str] = None @dataclass @@ -609,6 +618,9 @@ class IvfSq: max_iterations: int = 50 sample_rate: int = 256 target_partition_size: Optional[int] = None + # Name of the accelerator (e.g. "cuda") to use for IVF training. When set, + # create_index() dispatches to pylance to build the index on the accelerator. + accelerator: Optional[str] = None @dataclass @@ -739,6 +751,9 @@ class IvfPq: max_iterations: int = 50 sample_rate: int = 256 target_partition_size: Optional[int] = None + # Name of the accelerator (e.g. "cuda") to use for IVF training. When set, + # create_index() dispatches to pylance to build the index on the accelerator. + accelerator: Optional[str] = None @dataclass @@ -792,6 +807,9 @@ class IvfRq: max_iterations: int = 50 sample_rate: int = 256 target_partition_size: Optional[int] = None + # Name of the accelerator (e.g. "cuda") to use for IVF training. When set, + # create_index() dispatches to pylance to build the index on the accelerator. + accelerator: Optional[str] = None __all__ = [ diff --git a/python/python/lancedb/merge.py b/python/python/lancedb/merge.py index b2564740c..6085f5a06 100644 --- a/python/python/lancedb/merge.py +++ b/python/python/lancedb/merge.py @@ -34,6 +34,8 @@ class LanceMergeInsertBuilder(object): self._when_not_matched_by_source_condition = None self._timeout = None self._use_index = True + self._use_lsm_write = None + self._validate_single_shard = None def when_matched_update_all( self, *, where: Optional[str] = None @@ -96,6 +98,46 @@ class LanceMergeInsertBuilder(object): self._use_index = use_index return self + def use_lsm_write(self, use_lsm_write: bool) -> LanceMergeInsertBuilder: + """ + Controls whether the merge uses the MemWAL LSM write path. + + By default (unset), a `merge_insert` on a table with an LSM write spec + is routed through Lance's MemWAL shard writer, and a table without one + uses the standard path. Pass `False` to force the standard path even + when a spec is set. Pass `True` to require a spec — `merge_insert` + raises an error if none is installed. + + Parameters + ---------- + use_lsm_write: bool + Whether to use the LSM write path. + """ + self._use_lsm_write = use_lsm_write + return self + + def validate_single_shard( + self, validate_single_shard: bool + ) -> LanceMergeInsertBuilder: + """ + Controls how an LSM merge checks that its input targets a single shard. + + When a table has an LSM write spec, every row in a `merge_insert` call + must route to the same shard. When `True` (the default), every row is + inspected to verify this. When `False`, only the first row is inspected + and the shard it routes to is used for the whole input — a faster path + for callers that have already pre-sharded their input. + + Has no effect on tables without an LSM write spec. + + Parameters + ---------- + validate_single_shard: bool + Whether to check every row routes to one shard. Defaults to `True`. + """ + self._validate_single_shard = validate_single_shard + return self + def execute( self, new_data: DATA, diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 8f1aeda66..4421b057c 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -109,6 +109,7 @@ class RemoteDBConnection(DBConnection): connection_timeout: Optional[float] = None, read_timeout: Optional[float] = None, storage_options: Optional[Dict[str, str]] = None, + read_consistency_interval: Optional[timedelta] = None, ): """Connect to a remote LanceDB database.""" if isinstance(client_config, dict): @@ -167,6 +168,7 @@ class RemoteDBConnection(DBConnection): host_override=host_override, client_config=client_config, storage_options=storage_options, + read_consistency_interval=read_consistency_interval, ) ) diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 189042898..019f91044 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -2,12 +2,25 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors from datetime import timedelta +import deprecation import logging from functools import cached_property import os -from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Union, + Literal, + overload, +) import warnings +from lancedb import __version__ + from lancedb._lancedb import ( AddColumnsResult, AddResult, @@ -33,6 +46,7 @@ from lancedb.index import ( LabelList, ) from lancedb.remote.db import LOOP +from lancedb.table import IndexConfigType, KNOWN_METRICS import pyarrow as pa from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME @@ -195,6 +209,11 @@ class RemoteTable(Table): """List all the stats of a specified index""" return LOOP.run(self._table.index_stats(index_uuid)) + @deprecation.deprecated( + deprecated_in="0.25.0", + current_version=__version__, + details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.", + ) def create_scalar_index( self, column: str, @@ -204,7 +223,12 @@ class RemoteTable(Table): wait_timeout: Optional[timedelta] = None, name: Optional[str] = None, ): - """Creates a scalar index + """Creates a scalar index. + + .. deprecated:: 0.25.0 + Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead. + Example: ``table.create_index("column", config=BTree())`` + Parameters ---------- column : str @@ -235,6 +259,11 @@ class RemoteTable(Table): ) ) + @deprecation.deprecated( + deprecated_in="0.25.0", + current_version=__version__, + details="Use create_index() with config=FTS() instead.", + ) def create_fts_index( self, column: str, @@ -255,6 +284,12 @@ class RemoteTable(Table): prefix_only: bool = False, name: Optional[str] = None, ): + """Create a full-text search index on a column. + + .. deprecated:: 0.25.0 + Use :meth:`create_index` with an FTS config instead. + Example: ``table.create_index("text_column", config=FTS())`` + """ config = FTS( with_position=with_position, base_tokenizer=base_tokenizer, @@ -278,9 +313,43 @@ class RemoteTable(Table): ) ) + # New unified API overload + @overload def create_index( self, - metric="l2", + column: str, + /, + *, + config: IndexConfigType, + wait_timeout: Optional[timedelta] = ..., + name: Optional[str] = ..., + train: bool = ..., + ) -> None: ... + + # Legacy API overload (deprecated) + @overload + def create_index( + self, + metric: Literal["l2", "cosine", "dot", "hamming"] = ..., + vector_column_name: str = ..., + index_cache_size: Optional[int] = ..., + num_partitions: Optional[int] = ..., + num_sub_vectors: Optional[int] = ..., + replace: Optional[bool] = ..., + accelerator: Optional[str] = ..., + index_type: Literal[ + "VECTOR", "IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ" + ] = ..., + wait_timeout: Optional[timedelta] = ..., + *, + num_bits: int = ..., + name: Optional[str] = ..., + train: bool = ..., + ) -> None: ... + + def create_index( + self, + metric: str = "l2", vector_column_name: str = VECTOR_COLUMN_NAME, index_cache_size: Optional[int] = None, num_partitions: Optional[int] = None, @@ -291,89 +360,113 @@ class RemoteTable(Table): wait_timeout: Optional[timedelta] = None, *, num_bits: int = 8, + config: Optional[IndexConfigType] = None, name: Optional[str] = None, train: bool = True, ): - """Create an index on the table. + """Create an index on a column. - Parameters - ---------- - metric : str - The metric to use for the index. Default is "l2". - vector_column_name : str - The name of the vector column. Default is "vector". + This method supports both the new unified API and the legacy API + for backwards compatibility. The new API takes the column name as the + first positional argument and an index configuration object via + ``config``; the legacy API takes the distance metric as the first + argument plus separate ``vector_column_name`` / ``num_partitions`` / + etc. parameters, and emits a ``DeprecationWarning``. Examples -------- - >>> import lancedb - >>> import uuid - >>> from lancedb.schema import vector - >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP - ... region="...") # doctest: +SKIP - >>> table_name = uuid.uuid4().hex - >>> schema = pa.schema( - ... [ - ... pa.field("id", pa.uint32(), False), - ... pa.field("vector", vector(128), False), - ... pa.field("s", pa.string(), False), - ... ] + New API (recommended): + + >>> table.create_index( # doctest: +SKIP + ... "vector", config=IvfPq(distance_type="l2") ... ) - >>> table = db.create_table( # doctest: +SKIP - ... table_name, # doctest: +SKIP - ... schema=schema, # doctest: +SKIP + >>> table.create_index("category", config=BTree()) # doctest: +SKIP + >>> table.create_index("content", config=FTS()) # doctest: +SKIP + + Legacy API (deprecated): + + >>> table.create_index( # doctest: +SKIP + ... "l2", vector_column_name="vector" ... ) - >>> table.create_index("l2", "vector") # doctest: +SKIP """ + # Detect whether this is a legacy API call + is_legacy = self._is_legacy_create_index_call( + metric, + config, + num_partitions, + num_sub_vectors, + vector_column_name, + accelerator, + index_cache_size, + replace, + ) - if accelerator is not None: - logging.warning( - "GPU accelerator is not yet supported on LanceDB cloud." - "If you have 100M+ vectors to index," - "please contact us at contact@lancedb.com" - ) - if replace is not None: - logging.warning( - "replace is not supported on LanceDB cloud." - "Existing indexes will always be replaced." + if is_legacy: + warnings.warn( + "The create_index() API with metric/num_partitions parameters is " + "deprecated and will be removed in a future version. " + "Please migrate to the new unified API:\n" + " # Old (deprecated):\n" + " table.create_index('l2', vector_column_name='my_vector')\n" + " # New (recommended):\n" + " table.create_index('my_vector', config=IvfPq(distance_type='l2'))", + DeprecationWarning, + stacklevel=2, ) - index_type = index_type.upper() - if index_type == "VECTOR" or index_type == "IVF_PQ": - config = IvfPq( - distance_type=metric, - num_partitions=num_partitions, - num_sub_vectors=num_sub_vectors, - num_bits=num_bits, - ) - elif index_type == "IVF_RQ": - config = IvfRq( - distance_type=metric, - num_partitions=num_partitions, - num_bits=num_bits, - ) - elif index_type == "IVF_SQ": - config = IvfSq(distance_type=metric, num_partitions=num_partitions) - elif index_type == "IVF_HNSW_PQ": - raise ValueError( - "IVF_HNSW_PQ is not supported on LanceDB cloud." - "Please use IVF_HNSW_SQ instead." - ) - elif index_type == "IVF_HNSW_SQ": - config = HnswSq(distance_type=metric, num_partitions=num_partitions) - elif index_type == "IVF_HNSW_FLAT": - config = HnswFlat(distance_type=metric, num_partitions=num_partitions) - elif index_type == "IVF_FLAT": - config = IvfFlat(distance_type=metric, num_partitions=num_partitions) + column = vector_column_name + + if accelerator is not None: + logging.warning( + "GPU accelerator is not yet supported on LanceDB cloud." + "If you have 100M+ vectors to index," + "please contact us at contact@lancedb.com" + ) + if replace is not None: + logging.warning( + "replace is not supported on LanceDB cloud." + "Existing indexes will always be replaced." + ) + + idx_type = index_type.upper() + if idx_type == "VECTOR" or idx_type == "IVF_PQ": + config = IvfPq( + distance_type=metric, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + num_bits=num_bits, + ) + elif idx_type == "IVF_RQ": + config = IvfRq( + distance_type=metric, + num_partitions=num_partitions, + num_bits=num_bits, + ) + elif idx_type == "IVF_SQ": + config = IvfSq(distance_type=metric, num_partitions=num_partitions) + elif idx_type == "IVF_HNSW_PQ": + raise ValueError( + "IVF_HNSW_PQ is not supported on LanceDB cloud." + "Please use IVF_HNSW_SQ instead." + ) + elif idx_type == "IVF_HNSW_SQ": + config = HnswSq(distance_type=metric, num_partitions=num_partitions) + elif idx_type == "IVF_HNSW_FLAT": + config = HnswFlat(distance_type=metric, num_partitions=num_partitions) + elif idx_type == "IVF_FLAT": + config = IvfFlat(distance_type=metric, num_partitions=num_partitions) + else: + raise ValueError( + f"Unknown vector index type: {idx_type}. Valid options are" + " 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ'," + " 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'" + ) else: - raise ValueError( - f"Unknown vector index type: {index_type}. Valid options are" - " 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ'," - " 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'" - ) + column = metric LOOP.run( self._table.create_index( - vector_column_name, + column, config=config, wait_timeout=wait_timeout, name=name, @@ -381,6 +474,37 @@ class RemoteTable(Table): ) ) + def _is_legacy_create_index_call( + self, + first_arg: str, + config: Optional[IndexConfigType], + num_partitions: Optional[int], + num_sub_vectors: Optional[int], + vector_column_name: str, + accelerator: Optional[str], + index_cache_size: Optional[int], + replace: Optional[bool], + ) -> bool: + """Detect if this is a legacy create_index call.""" + if config is not None: + return False + if any( + x is not None + for x in ( + num_partitions, + num_sub_vectors, + accelerator, + index_cache_size, + replace, + ) + ): + return True + if vector_column_name != VECTOR_COLUMN_NAME: + return True + if first_arg.lower() in KNOWN_METRICS: + return True + return False + def add( self, data: DATA, @@ -741,6 +865,10 @@ class RemoteTable(Table): """Not supported on LanceDB Cloud.""" return LOOP.run(self._table.unset_lsm_write_spec()) + def close_lsm_writers(self) -> None: + """No-op on LanceDB Cloud (no local shard writers).""" + return LOOP.run(self._table.close_lsm_writers()) + def drop_index(self, index_name: str): return LOOP.run(self._table.drop_index(index_name)) diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py index 9f1d645c9..74f23ea61 100644 --- a/python/python/lancedb/rerankers/linear_combination.py +++ b/python/python/lancedb/rerankers/linear_combination.py @@ -102,8 +102,15 @@ class LinearCombinationReranker(Reranker): combined_list = [] for row_id, result in results.items(): + # Convert vector distance to a relevance score in [0, 1] where + # higher is better. Missing vector entries are penalised with + # `_invert_score(fill)` = 1 - fill (= 0.0 for the default fill=1). vector_score = self._invert_score(result.get("_distance", fill)) - fts_score = result.get("_score", fill) + # FTS scores (BM25) are already in a "higher = more relevant" space. + # Missing FTS entries are penalised symmetrically: we use + # `1 - fill` so that the same `fill` value drives both missing-vector + # and missing-FTS penalties in the same direction. + fts_score = result.get("_score", 1 - fill) result["_relevance_score"] = self._combine_score(vector_score, fts_score) combined_list.append(result) @@ -123,8 +130,12 @@ class LinearCombinationReranker(Reranker): return tbl def _combine_score(self, vector_score, fts_score): - # these scores represent distance - return 1 - (self.weight * vector_score + (1 - self.weight) * fts_score) + # Both vector_score (inverted distance) and fts_score are in a + # "higher = more relevant" space. A straight weighted average gives + # higher _relevance_score to better matches, as expected. + # Previously this returned `1 - (...)` which inverted the final + # ranking so that the *least* relevant document ranked first. + return self.weight * vector_score + (1 - self.weight) * fts_score def _invert_score(self, dist: float): # Invert the score between relevance and distance diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 3a9ae0801..2de369419 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -174,6 +174,24 @@ if TYPE_CHECKING: DistanceType, ) +# Type alias for index configuration objects +IndexConfigType = Union[ + IvfFlat, + IvfPq, + IvfSq, + IvfRq, + HnswFlat, + HnswPq, + HnswSq, + BTree, + Bitmap, + LabelList, + FTS, +] + +# Known distance metrics for legacy API detection +KNOWN_METRICS = {"l2", "cosine", "dot", "hamming"} + def _into_pyarrow_reader( data, schema: Optional[pa.Schema] = None @@ -807,11 +825,49 @@ class Table(ABC): """ raise NotImplementedError + # New unified API overload + @overload def create_index( self, - metric="l2", - num_partitions=256, - num_sub_vectors=96, + column: str, + /, + *, + config: IndexConfigType, + replace: bool = ..., + wait_timeout: Optional[timedelta] = ..., + name: Optional[str] = ..., + train: bool = ..., + ) -> None: ... + + # Legacy API overload (deprecated) + @overload + def create_index( + self, + metric: Literal["l2", "cosine", "dot", "hamming"] = ..., + num_partitions: Optional[int] = ..., + num_sub_vectors: Optional[int] = ..., + vector_column_name: str = ..., + replace: bool = ..., + accelerator: Optional[str] = ..., + index_cache_size: Optional[int] = ..., + *, + index_type: VectorIndexType = ..., + wait_timeout: Optional[timedelta] = ..., + num_bits: int = ..., + max_iterations: int = ..., + sample_rate: int = ..., + m: int = ..., + ef_construction: int = ..., + name: Optional[str] = ..., + train: bool = ..., + target_partition_size: Optional[int] = ..., + ) -> None: ... + + def create_index( + self, + metric: DistanceType = "l2", + num_partitions: Optional[int] = None, + num_sub_vectors: Optional[int] = None, vector_column_name: str = VECTOR_COLUMN_NAME, replace: bool = True, accelerator: Optional[str] = None, @@ -824,46 +880,53 @@ class Table(ABC): sample_rate: int = 256, m: int = 20, ef_construction: int = 300, + config: Optional[IndexConfigType] = None, name: Optional[str] = None, train: bool = True, target_partition_size: Optional[int] = None, ): - """Create an index on the table. + """Create an index on a column. + + This method supports both the new unified API and the legacy API + for backwards compatibility. The new API takes the column name as the + first positional argument and an index configuration object via + ``config``; the legacy API takes the distance metric as the first + argument plus separate ``vector_column_name`` / ``num_partitions`` / + etc. parameters, and emits a ``DeprecationWarning``. Parameters ---------- - metric: str, default "l2" - The distance metric to use when creating the index. - Valid values are "l2", "cosine", "dot", or "hamming". - l2 is euclidean distance. - Hamming is available only for binary vectors. - num_partitions: int, default 256 - The number of IVF partitions to use when creating the index. - Default is 256. - num_sub_vectors: int, default 96 - The number of PQ sub-vectors to use when creating the index. - Default is 96. - vector_column_name: str, default "vector" - The vector column name to create the index. - replace: bool, default True - - If True, replace the existing index if it exists. + metric : str + For new API: the column name to index. + For legacy API: the distance metric ("l2", "cosine", "dot", "hamming"). + config : IndexConfigType, optional + The index configuration object. If provided, uses the new unified API. + Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq, + BTree, Bitmap, LabelList, FTS. + replace : bool, default True + Whether to replace an existing index on this column. + wait_timeout : timedelta, optional + Timeout to wait for async indexing to complete. + name : str, optional + Custom name for the index. + train : bool, default True + Whether to train the index with existing data. - - If False, raise an error if duplicate index exists. - accelerator: str, default None - If set, use the given accelerator to create the index. - Only support "cuda" for now. - index_cache_size : int, optional - The size of the index cache in number of entries. Default value is 256. - num_bits: int - The number of bits to encode sub-vectors. Only used with the IVF_PQ index. - Only 4 and 8 are supported. - wait_timeout: timedelta, optional - The timeout to wait if indexing is asynchronous. - name: str, optional - The name of the index. If not provided, a default name will be generated. - train: bool, default True - Whether to train the index with existing data. Vector indices always train - with existing data. + Examples + -------- + New API (recommended): + + >>> table.create_index( # doctest: +SKIP + ... "vector", config=IvfPq(distance_type="l2") + ... ) + >>> table.create_index("category", config=BTree()) # doctest: +SKIP + >>> table.create_index("content", config=FTS()) # doctest: +SKIP + + Legacy API (deprecated): + + >>> table.create_index( # doctest: +SKIP + ... "l2", vector_column_name="vector" + ... ) """ raise NotImplementedError @@ -1188,7 +1251,7 @@ class Table(ABC): ... .when_not_matched_insert_all() \\ ... .execute(new_data) >>> res - MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1) + MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3) >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -2250,11 +2313,51 @@ class LanceTable(Table): dataset, allow_pyarrow_filter=False, batch_size=batch_size ) + # New unified API overload + @overload def create_index( self, - metric: DistanceType = "l2", - num_partitions=None, - num_sub_vectors=None, + column: str, + /, + *, + config: IndexConfigType, + replace: bool = ..., + wait_timeout: Optional[timedelta] = ..., + name: Optional[str] = ..., + train: bool = ..., + ) -> None: ... + + # Legacy API overload (deprecated) + @overload + def create_index( + self, + metric: Literal["l2", "cosine", "dot", "hamming"] = ..., + num_partitions: Optional[int] = ..., + num_sub_vectors: Optional[int] = ..., + vector_column_name: str = ..., + replace: bool = ..., + accelerator: Optional[str] = ..., + index_cache_size: Optional[int] = ..., + num_bits: int = ..., + index_type: Literal[ + "IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ" + ] = ..., + max_iterations: int = ..., + sample_rate: int = ..., + m: int = ..., + ef_construction: int = ..., + *, + wait_timeout: Optional[timedelta] = ..., + name: Optional[str] = ..., + train: bool = ..., + target_partition_size: Optional[int] = ..., + ) -> None: ... + + def create_index( + self, + metric: str = "l2", + num_partitions: Optional[int] = None, + num_sub_vectors: Optional[int] = None, vector_column_name: str = VECTOR_COLUMN_NAME, replace: bool = True, accelerator: Optional[str] = None, @@ -2274,47 +2377,232 @@ class LanceTable(Table): m: int = 20, ef_construction: int = 300, *, + config: Optional[IndexConfigType] = None, + wait_timeout: Optional[timedelta] = None, name: Optional[str] = None, train: bool = True, target_partition_size: Optional[int] = None, ): - """Create an index on the table.""" - if accelerator is not None: - # accelerator is only supported through pylance. - self.to_lance().create_index( - column=vector_column_name, - index_type=index_type, + """Create an index on a column. + + This method supports both the new unified API and the legacy API + for backwards compatibility. The new API takes the column name as the + first positional argument and an index configuration object via + ``config``; the legacy API takes the distance metric as the first + argument plus separate ``vector_column_name`` / ``num_partitions`` / + etc. parameters, and emits a ``DeprecationWarning``. + + Parameters + ---------- + metric : str + For new API: the column name to index. + For legacy API: the distance metric ("l2", "cosine", "dot", "hamming"). + config : IndexConfigType, optional + The index configuration object. If provided, uses the new unified API. + Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq, + BTree, Bitmap, LabelList, FTS. + replace : bool, default True + Whether to replace an existing index on this column. + wait_timeout : timedelta, optional + Timeout to wait for async indexing to complete. + name : str, optional + Custom name for the index. + train : bool, default True + Whether to train the index with existing data. + + Examples + -------- + New API (recommended): + + >>> table.create_index( # doctest: +SKIP + ... "vector", config=IvfPq(distance_type="l2") + ... ) + >>> table.create_index("category", config=BTree()) # doctest: +SKIP + >>> table.create_index("content", config=FTS()) # doctest: +SKIP + + Legacy API (deprecated): + + >>> table.create_index( # doctest: +SKIP + ... "l2", vector_column_name="vector" + ... ) + """ + # Detect whether this is a legacy API call + is_legacy = self._is_legacy_create_index_call( + metric, + config, + num_partitions, + num_sub_vectors, + vector_column_name, + accelerator, + index_cache_size, + ) + + if is_legacy: + warnings.warn( + "The create_index() API with metric/num_partitions parameters is " + "deprecated and will be removed in a future version. " + "Please migrate to the new unified API:\n" + " # Old (deprecated):\n" + " table.create_index('l2', vector_column_name='my_vector')\n" + " # New (recommended):\n" + " table.create_index('my_vector', config=IvfPq(distance_type='l2'))", + DeprecationWarning, + stacklevel=2, + ) + + # Legacy API: first arg is the distance metric + column = vector_column_name + + # Build config from legacy parameters + config = self._build_vector_config_from_legacy_params( metric=metric, + index_type=index_type, num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, - replace=replace, - accelerator=accelerator, - index_cache_size=index_cache_size, num_bits=num_bits, + max_iterations=max_iterations, + sample_rate=sample_rate, m=m, ef_construction=ef_construction, target_partition_size=target_partition_size, + accelerator=accelerator, ) - self.checkout_latest() - return - elif index_type == "IVF_FLAT": - config = IvfFlat( + + # Handle accelerator through pylance + if accelerator is not None: + self.to_lance().create_index( + column=column, + index_type=index_type, + metric=metric, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + replace=replace, + accelerator=accelerator, + index_cache_size=index_cache_size, + num_bits=num_bits, + m=m, + ef_construction=ef_construction, + target_partition_size=target_partition_size, + ) + self.checkout_latest() + return + else: + # New API: metric is the column name + column = metric + + # Check if config has accelerator set and dispatch to pylance + if config is not None and hasattr(config, "accelerator"): + acc = getattr(config, "accelerator", None) + if acc is not None: + # Dispatch to pylance for GPU acceleration + index_type_map = { + "IvfFlat": "IVF_FLAT", + "IvfSq": "IVF_SQ", + "IvfPq": "IVF_PQ", + "IvfRq": "IVF_RQ", + "HnswPq": "IVF_HNSW_PQ", + "HnswSq": "IVF_HNSW_SQ", + } + cfg_type = type(config).__name__ + lance_index_type = index_type_map.get(cfg_type, "IVF_PQ") + + self.to_lance().create_index( + column=column, + index_type=lance_index_type, + metric=getattr(config, "distance_type", "l2"), + num_partitions=getattr(config, "num_partitions", None), + num_sub_vectors=getattr(config, "num_sub_vectors", None), + replace=replace, + accelerator=acc, + num_bits=getattr(config, "num_bits", 8), + m=getattr(config, "m", 20), + ef_construction=getattr(config, "ef_construction", 300), + target_partition_size=getattr( + config, "target_partition_size", None + ), + ) + self.checkout_latest() + return + + return LOOP.run( + self._table.create_index( + column, + replace=replace, + config=config, + wait_timeout=wait_timeout, + name=name, + train=train, + ) + ) + + def _is_legacy_create_index_call( + self, + first_arg: str, + config: Optional[IndexConfigType], + num_partitions: Optional[int], + num_sub_vectors: Optional[int], + vector_column_name: str, + accelerator: Optional[str], + index_cache_size: Optional[int], + ) -> bool: + """Detect if this is a legacy create_index call.""" + # If config is provided, it's definitely the new API + if config is not None: + return False + + # If old-style parameters were explicitly set, it's legacy + if any( + x is not None + for x in (num_partitions, num_sub_vectors, accelerator, index_cache_size) + ): + return True + + # If vector_column_name differs from default, it's legacy + if vector_column_name != VECTOR_COLUMN_NAME: + return True + + # If first arg is a known metric, assume legacy + if first_arg.lower() in KNOWN_METRICS: + return True + + # Otherwise assume new API + return False + + def _build_vector_config_from_legacy_params( + self, + metric: str, + index_type: str, + num_partitions: Optional[int], + num_sub_vectors: Optional[int], + num_bits: int, + max_iterations: int, + sample_rate: int, + m: int, + ef_construction: int, + target_partition_size: Optional[int], + accelerator: Optional[str], + ) -> IndexConfigType: + """Build an index config object from legacy parameters.""" + if index_type == "IVF_FLAT": + return IvfFlat( distance_type=metric, num_partitions=num_partitions, max_iterations=max_iterations, sample_rate=sample_rate, target_partition_size=target_partition_size, + accelerator=accelerator, ) elif index_type == "IVF_SQ": - config = IvfSq( + return IvfSq( distance_type=metric, num_partitions=num_partitions, max_iterations=max_iterations, sample_rate=sample_rate, target_partition_size=target_partition_size, + accelerator=accelerator, ) elif index_type == "IVF_PQ": - config = IvfPq( + return IvfPq( distance_type=metric, num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, @@ -2322,18 +2610,20 @@ class LanceTable(Table): max_iterations=max_iterations, sample_rate=sample_rate, target_partition_size=target_partition_size, + accelerator=accelerator, ) elif index_type == "IVF_RQ": - config = IvfRq( + return IvfRq( distance_type=metric, num_partitions=num_partitions, num_bits=num_bits, max_iterations=max_iterations, sample_rate=sample_rate, target_partition_size=target_partition_size, + accelerator=accelerator, ) elif index_type == "IVF_HNSW_PQ": - config = HnswPq( + return HnswPq( distance_type=metric, num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, @@ -2343,9 +2633,10 @@ class LanceTable(Table): m=m, ef_construction=ef_construction, target_partition_size=target_partition_size, + accelerator=accelerator, ) elif index_type == "IVF_HNSW_SQ": - config = HnswSq( + return HnswSq( distance_type=metric, num_partitions=num_partitions, max_iterations=max_iterations, @@ -2353,9 +2644,10 @@ class LanceTable(Table): m=m, ef_construction=ef_construction, target_partition_size=target_partition_size, + accelerator=accelerator, ) elif index_type == "IVF_HNSW_FLAT": - config = HnswFlat( + return HnswFlat( distance_type=metric, num_partitions=num_partitions, max_iterations=max_iterations, @@ -2367,16 +2659,6 @@ class LanceTable(Table): else: raise ValueError(f"Unknown index type {index_type}") - return LOOP.run( - self._table.create_index( - vector_column_name, - replace=replace, - config=config, - name=name, - train=train, - ) - ) - def drop_index(self, name: str) -> None: """ Drops an index from the table @@ -2476,6 +2758,11 @@ class LanceTable(Table): """ return LOOP.run(self._table.latest_storage_options()) + @deprecation.deprecated( + deprecated_in="0.25.0", + current_version=__version__, + details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.", + ) def create_scalar_index( self, column: str, @@ -2484,6 +2771,12 @@ class LanceTable(Table): index_type: ScalarIndexType = "BTREE", name: Optional[str] = None, ): + """Create a scalar index on a column. + + .. deprecated:: 0.25.0 + Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead. + Example: ``table.create_index("column", config=BTree())`` + """ if index_type == "BTREE": config = BTree() elif index_type == "BITMAP": @@ -2496,6 +2789,11 @@ class LanceTable(Table): self._table.create_index(column, replace=replace, config=config, name=name) ) + @deprecation.deprecated( + deprecated_in="0.25.0", + current_version=__version__, + details="Use create_index() with config=FTS() instead.", + ) def create_fts_index( self, field_names: Union[str, List[str]], @@ -2519,6 +2817,12 @@ class LanceTable(Table): prefix_only: bool = False, name: Optional[str] = None, ): + """Create a full-text search index on a column. + + .. deprecated:: 0.25.0 + Use :meth:`create_index` with an FTS config instead. + Example: ``table.create_index("text_column", config=FTS())`` + """ self._ensure_no_legacy_fts_index() if use_tantivy: @@ -3297,6 +3601,11 @@ class LanceTable(Table): [`AsyncTable.unset_lsm_write_spec`][lancedb.AsyncTable.unset_lsm_write_spec].""" return LOOP.run(self._table.unset_lsm_write_spec()) + def close_lsm_writers(self) -> None: + """Close cached MemWAL shard writers. See + [`AsyncTable.close_lsm_writers`][lancedb.AsyncTable.close_lsm_writers].""" + return LOOP.run(self._table.close_lsm_writers()) + def uses_v2_manifest_paths(self) -> bool: """ Check if the table is using the new v2 manifest paths. @@ -3905,6 +4214,16 @@ class AsyncTable: """ await self._inner.unset_lsm_write_spec() + async def close_lsm_writers(self) -> None: + """Drain and close any cached MemWAL shard writers for this table. + + When an LSM write spec is installed, `merge_insert` opens MemWAL shard + writers and caches them for reuse across calls. This closes them, + flushing pending data; writers reopen lazily on the next + `merge_insert`. It is a no-op when no writers are cached. + """ + await self._inner.close_lsm_writers() + @property def name(self) -> str: """The name of the table.""" @@ -4355,7 +4674,7 @@ class AsyncTable: ... .when_not_matched_insert_all() \\ ... .execute(new_data) >>> res - MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1) + MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3) >>> # The order of new rows is non-deterministic since we use >>> # a hash-join as part of this operation and so we sort here >>> table.to_arrow().sort_by("a").to_pandas() @@ -4735,6 +5054,8 @@ class AsyncTable: when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition, timeout=merge._timeout, use_index=merge._use_index, + use_lsm_write=merge._use_lsm_write, + validate_single_shard=merge._validate_single_shard, ), ) diff --git a/python/python/tests/docs/test_merge_insert.py b/python/python/tests/docs/test_merge_insert.py index 228faa31b..adf812219 100644 --- a/python/python/tests/docs/test_merge_insert.py +++ b/python/python/tests/docs/test_merge_insert.py @@ -57,7 +57,7 @@ async def test_upsert_async(mem_db_async): await table.count_rows() # 3 res # MergeResult(version=2, num_updated_rows=1, - # num_inserted_rows=1, num_deleted_rows=0) + # num_inserted_rows=1, num_deleted_rows=0, num_rows=2) # --8<-- [end:upsert_basic_async] assert await table.count_rows() == 3 assert res.version == 2 @@ -86,7 +86,7 @@ def test_insert_if_not_exists(mem_db): table.count_rows() # 3 res # MergeResult(version=2, num_updated_rows=0, - # num_inserted_rows=1, num_deleted_rows=0) + # num_inserted_rows=1, num_deleted_rows=0, num_rows=1) # --8<-- [end:insert_if_not_exists] assert table.count_rows() == 3 assert res.version == 2 @@ -116,7 +116,7 @@ async def test_insert_if_not_exists_async(mem_db_async): await table.count_rows() # 3 res # MergeResult(version=2, num_updated_rows=0, - # num_inserted_rows=1, num_deleted_rows=0) + # num_inserted_rows=1, num_deleted_rows=0, num_rows=1) # --8<-- [end:insert_if_not_exists] assert await table.count_rows() == 3 assert res.version == 2 @@ -150,7 +150,7 @@ def test_replace_range(mem_db): table.count_rows("doc_id = 1") # 1 res # MergeResult(version=2, num_updated_rows=1, - # num_inserted_rows=0, num_deleted_rows=1) + # num_inserted_rows=0, num_deleted_rows=1, num_rows=1) # --8<-- [end:insert_if_not_exists] assert table.count_rows("doc_id = 1") == 1 assert res.version == 2 @@ -185,7 +185,7 @@ async def test_replace_range_async(mem_db_async): await table.count_rows("doc_id = 1") # 1 res # MergeResult(version=2, num_updated_rows=1, - # num_inserted_rows=0, num_deleted_rows=1) + # num_inserted_rows=0, num_deleted_rows=1, num_rows=1) # --8<-- [end:insert_if_not_exists] assert await table.count_rows("doc_id = 1") == 1 assert res.version == 2 diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index d3db372de..9495fb330 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -466,7 +466,8 @@ async def test_create_table_v2_manifest_paths_async(tmp_path): assert await tbl.uses_v2_manifest_paths() manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions" for manifest in os.listdir(manifests_dir): - assert re.match(r"\d{20}\.manifest", manifest) + if manifest.endswith(".manifest"): + assert re.match(r"\d{20}\.manifest", manifest) # Start a table in V1 mode then migrate tbl = await db_no_v2_paths.create_table( @@ -476,13 +477,15 @@ async def test_create_table_v2_manifest_paths_async(tmp_path): assert not await tbl.uses_v2_manifest_paths() manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions" for manifest in os.listdir(manifests_dir): - assert re.match(r"\d\.manifest", manifest) + if manifest.endswith(".manifest"): + assert re.match(r"\d\.manifest", manifest) await tbl.migrate_manifest_paths_v2() assert await tbl.uses_v2_manifest_paths() for manifest in os.listdir(manifests_dir): - assert re.match(r"\d{20}\.manifest", manifest) + if manifest.endswith(".manifest"): + assert re.match(r"\d{20}\.manifest", manifest) @pytest.mark.asyncio diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index 62f8f93d3..db83cb678 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -215,11 +215,12 @@ def test_reject_legacy_tantivy_index(table): @pytest.mark.parametrize("with_position", [True, False]) def test_create_inverted_index(table, with_position): - table.create_fts_index( - "text", - with_position=with_position, - name="custom_fts_index", - ) + with pytest.warns(DeprecationWarning, match="create_fts_index"): + table.create_fts_index( + "text", + with_position=with_position, + name="custom_fts_index", + ) indices = table.list_indices() fts_indices = [i for i in indices if i.index_type == "FTS"] assert any(i.name == "custom_fts_index" for i in fts_indices) diff --git a/python/python/tests/test_index.py b/python/python/tests/test_index.py index 0c71fc87b..18b845b2b 100644 --- a/python/python/tests/test_index.py +++ b/python/python/tests/test_index.py @@ -162,12 +162,13 @@ async def test_create_bitmap_index(some_table: AsyncTable): await some_table.create_index("data", config=Bitmap()) indices = await some_table.list_indices() assert len(indices) == 3 + # list_indices returns indices in alphabetical order by name assert indices[0].index_type == "Bitmap" - assert indices[0].columns == ["id"] + assert indices[0].columns == ["data"] assert indices[1].index_type == "Bitmap" - assert indices[1].columns == ["is_active"] + assert indices[1].columns == ["id"] assert indices[2].index_type == "Bitmap" - assert indices[2].columns == ["data"] + assert indices[2].columns == ["is_active"] index_name = indices[0].name stats = await some_table.index_stats(index_name) diff --git a/python/python/tests/test_lsm_write_spec.py b/python/python/tests/test_lsm_write_spec.py index b81153994..d9d75d3a9 100644 --- a/python/python/tests/test_lsm_write_spec.py +++ b/python/python/tests/test_lsm_write_spec.py @@ -40,16 +40,6 @@ def _make_table(tmp_path): def test_set_lsm_write_spec_validates(tmp_path): _db, table = _make_table(tmp_path) - # No PK set yet. - with pytest.raises(Exception, match="primary key"): - table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4)) - - table.set_unenforced_primary_key("id") - - # Column mismatch. - with pytest.raises(Exception, match="match"): - table.set_lsm_write_spec(LsmWriteSpec.bucket("v", 4)) - # Out-of-range num_buckets. with pytest.raises(Exception, match="num_buckets"): table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 0)) @@ -70,7 +60,6 @@ def test_unset_lsm_write_spec(tmp_path): table.unset_lsm_write_spec() # Install a spec, then remove it; afterwards a fresh spec can be set. - table.set_unenforced_primary_key("id") table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4)) table.unset_lsm_write_spec() # A second unset errors — there is no spec left to remove. diff --git a/python/python/tests/test_merge_insert_lsm.py b/python/python/tests/test_merge_insert_lsm.py new file mode 100644 index 000000000..abdfb306d --- /dev/null +++ b/python/python/tests/test_merge_insert_lsm.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +"""Tests for the MemWAL LSM ``merge_insert`` dispatch.""" + +from datetime import timedelta + +import lancedb +import pyarrow as pa +import pytest +from lancedb._lancedb import LsmWriteSpec + +SCHEMA = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("value", pa.int64(), nullable=False), + ] +) + +REGION_SCHEMA = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("region", pa.utf8(), nullable=False), + ] +) + + +def _reader(ids): + batch = pa.RecordBatch.from_arrays( + [ + pa.array(ids, type=pa.int64()), + pa.array(list(range(len(ids))), type=pa.int64()), + ], + schema=SCHEMA, + ) + return pa.RecordBatchReader.from_batches(SCHEMA, [batch]) + + +def _region_reader(rows): + batch = pa.RecordBatch.from_arrays( + [ + pa.array([row[0] for row in rows], type=pa.int64()), + pa.array([row[1] for row in rows], type=pa.utf8()), + ], + schema=REGION_SCHEMA, + ) + return pa.RecordBatchReader.from_batches(REGION_SCHEMA, [batch]) + + +def _bucket_table(tmp_path): + """A table with ``id`` as the primary key and a single-bucket LSM spec.""" + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _reader([1, 2, 3])) + table.set_unenforced_primary_key("id") + # num_buckets = 1: every row routes to the single bucket. + table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1)) + return table + + +def test_lsm_merge_insert_bucket(tmp_path): + table = _bucket_table(tmp_path) + # Empty `on` defaults to the primary key. + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([3, 4, 5])) + ) + # LSM path: rows go to the MemWAL, so only num_rows is populated. + assert result.num_rows == 3 + assert result.version == 0 + assert result.num_inserted_rows == 0 + assert result.num_updated_rows == 0 + + +def test_lsm_merge_insert_unsharded(tmp_path): + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _reader([1, 2, 3])) + table.set_unenforced_primary_key("id") + table.set_lsm_write_spec(LsmWriteSpec.unsharded()) + result = ( + table.merge_insert("id") + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([10, 11, 12, 13])) + ) + assert result.num_rows == 4 + + +def test_lsm_merge_insert_identity(tmp_path): + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _region_reader([(1, "us"), (2, "us")])) + table.set_unenforced_primary_key("id") + table.set_lsm_write_spec(LsmWriteSpec.identity("region")) + # All rows share one identity value, so they route to one shard. + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_region_reader([(3, "us"), (4, "us")])) + ) + assert result.num_rows == 2 + + +def test_lsm_merge_insert_use_lsm_write_false(tmp_path): + table = _bucket_table(tmp_path) # rows id = 1, 2, 3 + # use_lsm_write(False) opts out: the standard path runs and commits. + result = ( + table.merge_insert("id") + .when_not_matched_insert_all() + .use_lsm_write(False) + .execute(_reader([3, 4, 5])) + ) + assert result.num_inserted_rows == 2 + assert table.count_rows() == 5 + + +def test_lsm_merge_insert_validate_single_shard_off(tmp_path): + table = _bucket_table(tmp_path) + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .validate_single_shard(False) + .execute(_reader([6, 7, 8])) + ) + assert result.num_rows == 3 + + +def test_lsm_merge_insert_use_lsm_write_true_requires_spec(tmp_path): + # A table with a primary key but no LSM write spec installed. + db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) + table = db.create_table("t", _reader([1, 2, 3])) + table.set_unenforced_primary_key("id") + with pytest.raises(Exception, match="use_lsm_write"): + ( + table.merge_insert("id") + .when_matched_update_all() + .when_not_matched_insert_all() + .use_lsm_write(True) + .execute(_reader([4])) + ) + + +def test_lsm_merge_insert_rejects_on_not_primary_key(tmp_path): + table = _bucket_table(tmp_path) + with pytest.raises(Exception, match="primary key"): + ( + table.merge_insert("value") + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([1])) + ) + + +def test_lsm_merge_insert_rejects_non_upsert(tmp_path): + table = _bucket_table(tmp_path) + # Insert-only (no when_matched_update_all) is not the upsert shape. + with pytest.raises(Exception, match="upsert"): + table.merge_insert([]).when_not_matched_insert_all().execute(_reader([4])) + + +def test_lsm_close_writers(tmp_path): + table = _bucket_table(tmp_path) + ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([7, 8])) + ) + table.close_lsm_writers() + # The writer reopens lazily on the next merge_insert. + result = ( + table.merge_insert([]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(_reader([9])) + ) + assert result.num_rows == 1 + + +@pytest.mark.asyncio +async def test_async_lsm_merge_insert(tmp_path): + db = await lancedb.connect_async( + tmp_path, read_consistency_interval=timedelta(seconds=0) + ) + table = await db.create_table("t", _reader([1, 2, 3])) + await table.set_unenforced_primary_key("id") + await table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1)) + + builder = ( + table.merge_insert([]).when_matched_update_all().when_not_matched_insert_all() + ) + result = await builder.execute(_reader([3, 4, 5])) + assert result.num_rows == 3 + await table.close_lsm_writers() diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 279f93658..4cc184c77 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -586,22 +586,25 @@ def test_table_create_indices(): # This is a smoke-test. table = db.create_table("test", [{"id": 1}]) - # Test create_scalar_index with custom name - table.create_scalar_index( - "id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx" - ) + # Test create_scalar_index with custom name (legacy method) + with pytest.warns(DeprecationWarning, match="create_scalar_index"): + table.create_scalar_index( + "id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx" + ) - # Test create_fts_index with custom name - table.create_fts_index( - "text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx" - ) + # Test create_fts_index with custom name (legacy method) + with pytest.warns(DeprecationWarning, match="create_fts_index"): + table.create_fts_index( + "text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx" + ) - # Test create_index with custom name - table.create_index( - vector_column_name="vector", - wait_timeout=timedelta(seconds=10), - name="custom_vector_idx", - ) + # Test create_index with custom name (legacy form: vector_column_name kwarg) + with pytest.warns(DeprecationWarning, match="create_index"): + table.create_index( + vector_column_name="vector", + wait_timeout=timedelta(seconds=10), + name="custom_vector_idx", + ) # Validate that the name parameter was passed correctly in requests assert len(received_requests) == 3 @@ -630,6 +633,98 @@ def test_table_create_indices(): table.drop_index("custom_fts_idx") +def test_remote_create_index_new_api(): + received_requests = [] + + def handler(request): + if request.path == "/v1/table/test/create_index/": + content_len = int(request.headers.get("Content-Length", 0)) + body = request.rfile.read(content_len) if content_len > 0 else b"" + received_requests.append(json.loads(body) if body else {}) + request.send_response(200) + request.end_headers() + elif request.path == "/v1/table/test/create/?mode=create": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"{}") + elif request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write( + json.dumps( + dict( + version=1, + schema=dict( + fields=[ + dict(name="id", type={"type": "int64"}, nullable=False), + dict( + name="category", + type={"type": "string"}, + nullable=False, + ), + dict( + name="text", type={"type": "string"}, nullable=False + ), + dict( + name="vector", + type={ + "type": "fixed_size_list", + "fields": [ + dict( + name="item", + type={"type": "float"}, + nullable=True, + ) + ], + "length": 2, + }, + nullable=False, + ), + ] + ), + ) + ).encode() + ) + else: + request.send_response(404) + request.end_headers() + + from lancedb.index import BTree, FTS, IvfPq, IvfRq + + with mock_lancedb_connection(handler) as db: + table = db.create_table("test", [{"id": 1}]) + + # New API: column-first, config= kwarg. Should NOT emit DeprecationWarning. + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("error", DeprecationWarning) + table.create_index("vector", config=IvfPq(distance_type="l2")) + table.create_index("category", config=BTree()) + table.create_index("text", config=FTS()) + # IvfRq via new API + table.create_index("vector", config=IvfRq(distance_type="l2")) + + # Legacy index_type="IVF_RQ" routes to IvfRq config under the hood. + with pytest.warns(DeprecationWarning, match="create_index"): + table.create_index( + vector_column_name="vector", + index_type="IVF_RQ", + num_partitions=8, + ) + + assert len(received_requests) == 5 + assert [req["column"] for req in received_requests] == [ + "vector", + "category", + "text", + "vector", + "vector", + ] + + def test_table_wait_for_index_timeout(): def handler(request): index_stats = dict( diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 3d028cb3a..c886772bb 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -603,3 +603,89 @@ def test_cross_encoder_reranker_return_all(tmp_path): assert "_relevance_score" in result.column_names assert "_score" in result.column_names assert "_distance" in result.column_names + + +# --------------------------------------------------------------------------- +# Regression tests for LinearCombinationReranker scoring bugs (issue #3154) +# --------------------------------------------------------------------------- + + +def test_linear_combination_best_match_ranks_first(): + """ + The document that is BOTH the closest vector match AND the only FTS match + must rank first. Previously _combine_score subtracted from 1, inverting + the ranking so the worst document ranked highest. + """ + reranker = LinearCombinationReranker(weight=0.7, return_score="all") + + # rowid 0: perfect vector match, sole FTS match → should rank 1st + # rowid 1: mediocre vector, no FTS match + # rowid 2: bad vector, no FTS match + vector_results = pa.Table.from_pydict( + { + "_rowid": [0, 1, 2], + "_distance": [0.0, 0.5, 0.9], + } + ) + fts_results = pa.Table.from_pydict( + { + "_rowid": [0], + "_score": [1.0], + } + ) + + combined = reranker.merge_results(vector_results, fts_results, fill=1.0) + scores = dict( + zip( + combined["_rowid"].to_pylist(), + combined["_relevance_score"].to_pylist(), + ) + ) + + # rowid 0 must have the highest relevance score + assert scores[0] > scores[1], ( + f"Best match (rowid 0, score={scores[0]:.4f}) should beat " + f"mid match (rowid 1, score={scores[1]:.4f})" + ) + assert scores[1] > scores[2], ( + f"Mid match (rowid 1, score={scores[1]:.4f}) should beat " + f"bad match (rowid 2, score={scores[2]:.4f})" + ) + + +def test_linear_combination_missing_fts_is_penalised(): + """ + A document with no FTS match must score *lower* than a document that + has a mediocre FTS match, everything else being equal. Previously + missing-FTS entries used fill=1.0 directly, which gave them a reward + (via the 1-(...) inversion) instead of a penalty. + """ + reranker = LinearCombinationReranker(weight=0.5, return_score="all") + + vector_results = pa.Table.from_pydict( + { + "_rowid": [0, 1], + "_distance": [0.2, 0.2], # identical vector scores + } + ) + fts_results = pa.Table.from_pydict( + { + "_rowid": [0], # rowid 1 has no FTS match + "_score": [0.3], # small FTS score + } + ) + + combined = reranker.merge_results(vector_results, fts_results, fill=1.0) + scores = dict( + zip( + combined["_rowid"].to_pylist(), + combined["_relevance_score"].to_pylist(), + ) + ) + + # rowid 0 has a small FTS score; rowid 1 has none. + # Even a small FTS contribution should beat having none at all. + assert scores[0] > scores[1], ( + f"Document with FTS score (rowid 0, {scores[0]:.4f}) should beat " + f"document with no FTS match (rowid 1, {scores[1]:.4f})" + ) diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index ed4656d81..2a07c2df6 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -4,6 +4,7 @@ import os import sys +import warnings from datetime import date, datetime, timedelta from time import sleep from typing import List @@ -11,7 +12,7 @@ from unittest.mock import patch import lancedb from lancedb.dependencies import _PANDAS_AVAILABLE -from lancedb.index import HnswFlat, HnswPq, HnswSq, IvfPq +from lancedb.index import BTree, FTS, HnswFlat, HnswPq, HnswSq, IvfPq import numpy as np import polars as pl import pyarrow as pa @@ -928,7 +929,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection): num_bits=4, ) mock_create_index.assert_called_with( - "vector", replace=True, config=expected_config, name=None, train=True + "vector", + replace=True, + config=expected_config, + wait_timeout=None, + name=None, + train=True, ) # Test with target_partition_size @@ -948,7 +954,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection): target_partition_size=8192, ) mock_create_index.assert_called_with( - "vector", replace=True, config=expected_config, name=None, train=True + "vector", + replace=True, + config=expected_config, + wait_timeout=None, + name=None, + train=True, ) # target_partition_size has a default value, @@ -967,7 +978,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection): num_bits=4, ) mock_create_index.assert_called_with( - "vector", replace=True, config=expected_config, name=None, train=True + "vector", + replace=True, + config=expected_config, + wait_timeout=None, + name=None, + train=True, ) table.create_index( @@ -978,7 +994,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection): ) expected_config = HnswPq(distance_type="dot") mock_create_index.assert_called_with( - "my_vector", replace=False, config=expected_config, name=None, train=True + "my_vector", + replace=False, + config=expected_config, + wait_timeout=None, + name=None, + train=True, ) table.create_index( @@ -993,7 +1014,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection): distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10 ) mock_create_index.assert_called_with( - "my_vector", replace=True, config=expected_config, name=None, train=True + "my_vector", + replace=True, + config=expected_config, + wait_timeout=None, + name=None, + train=True, ) table.create_index( @@ -1008,7 +1034,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection): distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10 ) mock_create_index.assert_called_with( - "my_vector", replace=True, config=expected_config, name=None, train=True + "my_vector", + replace=True, + config=expected_config, + wait_timeout=None, + name=None, + train=True, ) @@ -1032,6 +1063,7 @@ def test_create_index_name_and_train_parameters( "vector", replace=True, config=expected_config, + wait_timeout=None, name="my_custom_index", train=True, ) @@ -1039,13 +1071,82 @@ def test_create_index_name_and_train_parameters( # Test with train=False table.create_index(vector_column_name="vector", train=False) mock_create_index.assert_called_with( - "vector", replace=True, config=expected_config, name=None, train=False + "vector", + replace=True, + config=expected_config, + wait_timeout=None, + name=None, + train=False, ) # Test with both name and train table.create_index(vector_column_name="vector", name="my_index_name", train=True) mock_create_index.assert_called_with( - "vector", replace=True, config=expected_config, name="my_index_name", train=True + "vector", + replace=True, + config=expected_config, + wait_timeout=None, + name="my_index_name", + train=True, + ) + + +@patch("lancedb.table.AsyncTable.create_index") +def test_create_index_legacy_emits_deprecation_warning( + mock_create_index, mem_db: DBConnection +): + table = mem_db.create_table( + "test", + data=[{"vector": [3.1, 4.1]}, {"vector": [5.9, 26.5]}], + ) + + with pytest.warns(DeprecationWarning, match="create_index"): + table.create_index(metric="l2", num_partitions=8, vector_column_name="vector") + + +@patch("lancedb.table.AsyncTable.create_index") +def test_create_index_new_api(mock_create_index, mem_db: DBConnection): + table = mem_db.create_table( + "test", + data=[ + {"vector": [3.1, 4.1], "category": "a", "text": "hello world"}, + {"vector": [5.9, 26.5], "category": "b", "text": "goodbye"}, + ], + ) + + # Vector index via new API should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + table.create_index("vector", config=IvfPq(distance_type="l2")) + mock_create_index.assert_called_with( + "vector", + replace=True, + config=IvfPq(distance_type="l2"), + wait_timeout=None, + name=None, + train=True, + ) + + # Scalar index via new API + table.create_index("category", config=BTree()) + mock_create_index.assert_called_with( + "category", + replace=True, + config=BTree(), + wait_timeout=None, + name=None, + train=True, + ) + + # FTS index via new API + table.create_index("text", config=FTS(with_position=True)) + mock_create_index.assert_called_with( + "text", + replace=True, + config=FTS(with_position=True), + wait_timeout=None, + name=None, + train=True, ) @@ -1861,8 +1962,9 @@ def test_create_scalar_index(mem_db: DBConnection): "my_table", data=test_data, ) - # Test with default name - table.create_scalar_index("x") + # Test with default name; confirm DeprecationWarning fires + with pytest.warns(DeprecationWarning, match="create_scalar_index"): + table.create_scalar_index("x") indices = table.list_indices() assert len(indices) == 1 scalar_index = indices[0] diff --git a/python/src/table.rs b/python/src/table.rs index 546bec555..302c2bb46 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -143,18 +143,20 @@ pub struct MergeResult { pub num_inserted_rows: u64, pub num_deleted_rows: u64, pub num_attempts: u32, + pub num_rows: u64, } #[pymethods] impl MergeResult { pub fn __repr__(&self) -> String { format!( - "MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={})", + "MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={}, num_rows={})", self.version, self.num_updated_rows, self.num_inserted_rows, self.num_deleted_rows, - self.num_attempts + self.num_attempts, + self.num_rows ) } } @@ -167,6 +169,7 @@ impl From for MergeResult { num_inserted_rows: result.num_inserted_rows, num_deleted_rows: result.num_deleted_rows, num_attempts: result.num_attempts, + num_rows: result.num_rows, } } } @@ -194,6 +197,12 @@ impl LsmWriteSpec { } /// Identity sharding — shard by the raw value of `column`. + /// + /// `column` must be a deterministic function of the unenforced primary + /// key: every row with a given primary key must always produce the same + /// `column` value, or upserts of that key can land in different shards + /// and a stale version can win. Typically `column` is the primary key + /// itself or a stable attribute of it. #[staticmethod] pub fn identity(column: String) -> Self { Self { @@ -933,6 +942,12 @@ impl Table { if let Some(use_index) = parameters.use_index { builder.use_index(use_index); } + if let Some(use_lsm_write) = parameters.use_lsm_write { + builder.use_lsm_write(use_lsm_write); + } + if let Some(validate_single_shard) = parameters.validate_single_shard { + builder.validate_single_shard(validate_single_shard); + } future_into_py(self_.py(), async move { let res = builder.execute(Box::new(batches)).await.infer_error()?; @@ -971,6 +986,13 @@ impl Table { }) } + pub fn close_lsm_writers(self_: PyRef<'_, Self>) -> PyResult> { + let inner = self_.inner_ref()?.clone(); + future_into_py(self_.py(), async move { + inner.close_lsm_writers().await.infer_error() + }) + } + pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { @@ -1124,6 +1146,8 @@ pub struct MergeInsertParams { when_not_matched_by_source_condition: Option, timeout: Option, use_index: Option, + use_lsm_write: Option, + validate_single_shard: Option, } #[pyclass] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 76a06e6b8..f25b5b140 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "1.94.0" +channel = "1.95.0" diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index cccecf461..b42d8d235 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lancedb" -version = "0.30.0-beta.0" +version = "0.30.0-beta.1" edition.workspace = true description = "LanceDB: A serverless, low-latency vector database for AI applications" license.workspace = true @@ -75,7 +75,7 @@ reqwest = { version = "0.12.0", default-features = false, features = [ "stream", ], optional = true } http = { version = "1", optional = true } # Matching what is in reqwest -uuid = { version = "1.7.0", features = ["v4"] } +uuid = { version = "1.7.0", features = ["v4", "v5"] } polars-arrow = { version = ">=0.37,<0.40.0", optional = true } polars = { version = ">=0.37,<0.40.0", optional = true } hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [ @@ -104,6 +104,7 @@ datafusion.workspace = true http-body = "1" # Matching reqwest rstest = "0.23.0" test-log = "0.2" +serial_test = "3" [features] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 8034c2a53..c1d475c9c 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -812,8 +812,7 @@ impl ConnectBuilder { self } - /// The interval at which to check for updates from other processes. This - /// only affects LanceDB OSS. + /// The interval at which to check for updates from other processes. /// /// If left unset, consistency is not checked. For maximum read /// performance, this is the default. For strong consistency, set this to @@ -825,8 +824,11 @@ impl ConnectBuilder { /// This only affects read operations. Write operations are always /// consistent. /// - /// LanceDB Cloud uses eventual consistency under the hood, and is not - /// currently configurable. + /// # Cost + /// + /// Stronger consistency is not free. The smaller the interval, the more + /// often each read pays the cost of checking for updates against object + /// storage, raising per-read latency and cost. pub fn read_consistency_interval( mut self, read_consistency_interval: std::time::Duration, @@ -886,6 +888,7 @@ impl ConnectBuilder { options.host_override, self.request.client_config, storage_options.into(), + self.request.read_consistency_interval, )?); Ok(Connection { internal, diff --git a/rust/lancedb/src/dataloader/permutation/shuffle.rs b/rust/lancedb/src/dataloader/permutation/shuffle.rs index 7cd27e342..b26db7bea 100644 --- a/rust/lancedb/src/dataloader/permutation/shuffle.rs +++ b/rust/lancedb/src/dataloader/permutation/shuffle.rs @@ -464,11 +464,9 @@ mod tests { let mut iter = ids.into_iter().map(|o| o.unwrap()); while let Some(first) = iter.next() { let rows_left_in_clump = if first == 4470 { 19 } else { 29 }; - let mut expected_next = first + 1; - for _ in 0..rows_left_in_clump { + for expected_next in (first + 1)..=(first + rows_left_in_clump) { let next = iter.next().unwrap(); assert_eq!(next, expected_next); - expected_next += 1; } } } diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 7fd5c6497..6a44f7f1c 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -245,6 +245,9 @@ pub struct RestfulLanceDbClient { pub(crate) sender: S, pub(crate) id_delimiter: String, pub(crate) header_provider: Option>, + /// Connection-level read consistency interval. Drives the + /// `x-lancedb-min-timestamp` freshness header sent on read requests. + pub(crate) read_consistency_interval: Option, } impl std::fmt::Debug for RestfulLanceDbClient { @@ -338,6 +341,7 @@ impl RestfulLanceDbClient { host_override: Option, default_headers: HeaderMap, client_config: ClientConfig, + read_consistency_interval: Option, ) -> Result { // Get the timeouts let timeout = @@ -435,6 +439,7 @@ impl RestfulLanceDbClient { .clone() .unwrap_or("$".to_string()), header_provider: client_config.header_provider, + read_consistency_interval, }) } } @@ -840,6 +845,16 @@ pub mod test_utils { pub fn client_with_handler( handler: impl Fn(reqwest::Request) -> http::response::Response + Send + Sync + 'static, ) -> RestfulLanceDbClient + where + T: Into, + { + client_with_handler_and_interval(handler, None) + } + + pub fn client_with_handler_and_interval( + handler: impl Fn(reqwest::Request) -> http::response::Response + Send + Sync + 'static, + read_consistency_interval: Option, + ) -> RestfulLanceDbClient where T: Into, { @@ -857,6 +872,7 @@ pub mod test_utils { }, id_delimiter: "$".to_string(), header_provider: None, + read_consistency_interval, } } @@ -881,6 +897,7 @@ pub mod test_utils { }, id_delimiter: config.id_delimiter.unwrap_or_else(|| "$".to_string()), header_provider: config.header_provider, + read_consistency_interval: None, } } } @@ -888,8 +905,18 @@ pub mod test_utils { #[cfg(test)] mod tests { use super::*; + use serial_test::serial; use std::time::Duration; + // Serializes the env-var-mutating tests below: cargo test runs tests in + // parallel, but several of these tests read and write the same process- + // global env vars (`LANCEDB_USER_ID*`), so they would race without this. + static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + fn lock_env() -> std::sync::MutexGuard<'static, ()> { + ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner()) + } + #[test] fn test_timeout_config_default() { let config = TimeoutConfig::default(); @@ -1046,6 +1073,7 @@ mod tests { sender: Sender, id_delimiter: "+".to_string(), header_provider: Some(Arc::new(provider) as Arc), + read_consistency_interval: None, }; // Apply dynamic headers @@ -1081,6 +1109,7 @@ mod tests { sender: Sender, id_delimiter: "+".to_string(), header_provider: Some(Arc::new(provider) as Arc), + read_consistency_interval: None, }; // Apply dynamic headers @@ -1118,6 +1147,7 @@ mod tests { sender: Sender, id_delimiter: "+".to_string(), header_provider: Some(Arc::new(provider) as Arc), + read_consistency_interval: None, }; // Header provider errors should fail the request @@ -1143,7 +1173,9 @@ mod tests { } #[test] + #[serial(user_id_env)] fn test_resolve_user_id_none() { + let _guard = lock_env(); let config = ClientConfig::default(); // Clear env vars that might be set from other tests // SAFETY: This is only called in tests @@ -1155,7 +1187,9 @@ mod tests { } #[test] + #[serial(user_id_env)] fn test_resolve_user_id_from_env() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::set_var("LANCEDB_USER_ID", "env-user-id"); @@ -1169,7 +1203,9 @@ mod tests { } #[test] + #[serial(user_id_env)] fn test_resolve_user_id_from_env_key() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::remove_var("LANCEDB_USER_ID"); @@ -1189,7 +1225,9 @@ mod tests { } #[test] + #[serial(user_id_env)] fn test_resolve_user_id_direct_takes_precedence() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::set_var("LANCEDB_USER_ID", "env-user-id"); @@ -1206,7 +1244,9 @@ mod tests { } #[test] + #[serial(user_id_env)] fn test_resolve_user_id_empty_env_ignored() { + let _guard = lock_env(); // SAFETY: This is only called in tests unsafe { std::env::set_var("LANCEDB_USER_ID", ""); diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index dfe5d0c99..c11584157 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -206,6 +206,7 @@ impl RemoteDatabase { host_override: Option, client_config: ClientConfig, options: RemoteOptions, + read_consistency_interval: Option, ) -> Result { let parsed = super::client::parse_db_url(uri)?; let header_map = RestfulLanceDbClient::::default_headers( @@ -233,6 +234,7 @@ impl RemoteDatabase { host_override, header_map, client_config.clone(), + read_consistency_interval, )?; let table_cache = Cache::builder() diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index b50ae4e7a..dc16b61c6 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -25,7 +25,7 @@ use crate::table::MergeResult; use crate::table::Tags; use crate::table::UpdateResult; use crate::table::query::create_multi_vector_plan; -use crate::table::{AnyQuery, Filter, PreprocessingOutput, TableStatistics}; +use crate::table::{AnyQuery, Filter, Predicate, PreprocessingOutput, TableStatistics}; use crate::utils::background_cache::BackgroundCache; use crate::utils::{ resolve_arrow_field_path, supported_btree_data_type, supported_vector_data_type, @@ -62,15 +62,76 @@ use std::collections::HashMap; use std::io::Cursor; use std::pin::Pin; use std::sync::{Arc, Mutex}; -use std::time::Duration; +use std::time::{Duration, SystemTime}; use tokio::sync::RwLock; const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); +const MIN_VERSION_HEADER: HeaderName = HeaderName::from_static("x-lancedb-min-version"); +const MIN_TIMESTAMP_HEADER: HeaderName = HeaderName::from_static("x-lancedb-min-timestamp"); const METRIC_TYPE_KEY: &str = "metric_type"; const INDEX_TYPE_KEY: &str = "index_type"; const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(30); const SCHEMA_CACHE_REFRESH_WINDOW: Duration = Duration::from_secs(5); +/// Per-table state driving the freshness headers (`x-lancedb-min-version` and +/// `x-lancedb-min-timestamp`) sent on read requests. +#[derive(Debug, Default, Clone, Copy)] +struct FreshnessState { + /// Provides read-your-write within a single handle: writes that return a + /// version update this, and reads send it as `x-lancedb-min-version`. + min_version: Option, + /// Wall-clock time captured at the last [`BaseTable::checkout_latest`] + /// call. Subsequent reads send + /// `max(baseline, now - read_consistency_interval)` as + /// `x-lancedb-min-timestamp`. + /// + /// Without this, `checkout_latest()` would have no effect on subsequent + /// reads when `read_consistency_interval` is unset (the default): a + /// server-side cache could still serve a snapshot older than the moment + /// the user explicitly asked for "latest". The baseline forces the + /// server to skip any cache entry older than the checkout time, so the + /// `checkout_latest()` signal is preserved across reads on the same + /// handle regardless of the configured consistency interval. + checkout_baseline: Option, +} + +/// Snapshot of the headers that should be attached to a single read request. +#[derive(Debug, Default, Clone, Copy)] +struct FreshnessHeaders { + min_version: Option, + min_timestamp: Option, +} + +impl FreshnessHeaders { + fn apply(self, mut request: RequestBuilder) -> RequestBuilder { + if let Some(v) = self.min_version { + request = request.header(MIN_VERSION_HEADER, v.to_string()); + } + if let Some(ts) = self.min_timestamp { + let dt: chrono::DateTime = ts.into(); + request = request.header(MIN_TIMESTAMP_HEADER, dt.to_rfc3339()); + } + request + } +} + +fn compute_min_timestamp( + state: &FreshnessState, + interval: Option, + now: SystemTime, +) -> Option { + let interval_based = match interval { + None => None, + Some(d) if d.is_zero() => Some(now), + Some(d) => Some(now.checked_sub(d).unwrap_or(now)), + }; + match (interval_based, state.checkout_baseline) { + (None, None) => None, + (Some(t), None) | (None, Some(t)) => Some(t), + (Some(a), Some(b)) => Some(a.max(b)), + } +} + pub struct RemoteTags<'a, S: HttpSend = Sender> { inner: &'a RemoteTable, } @@ -80,8 +141,7 @@ impl Tags for RemoteTags<'_, S> { async fn list(&self) -> Result> { let request = self .inner - .client - .post(&format!("/v1/table/{}/tags/list/", self.inner.identifier)); + .post_read(&format!("/v1/table/{}/tags/list/", self.inner.identifier)); let (request_id, response) = self.inner.send(request, true).await?; let response = self .inner @@ -112,48 +172,13 @@ impl Tags for RemoteTags<'_, S> { } async fn get_version(&self, tag: &str) -> Result { - let request = self - .inner - .client - .post(&format!( - "/v1/table/{}/tags/version/", - self.inner.identifier - )) - .json(&serde_json::json!({ "tag": tag })); - - let (request_id, response) = self.inner.send(request, true).await?; - let response = self - .inner - .check_table_response(&request_id, response) - .await?; - - match response.text().await { - Ok(body) => { - let value: serde_json::Value = - serde_json::from_str(&body).map_err(|e| Error::Http { - source: format!("Failed to parse tag version: {}", e).into(), - request_id: request_id.clone(), - status_code: None, - })?; - - value - .get("version") - .and_then(|v| v.as_u64()) - .ok_or_else(|| Error::Http { - source: format!("Invalid tag version response: {}", body).into(), - request_id, - status_code: None, - }) - } - Err(err) => { - let status_code = err.status(); - Err(Error::Http { - source: Box::new(err), - request_id, - status_code, - }) - } - } + let request = self.inner.post_read(&format!( + "/v1/table/{}/tags/version/", + self.inner.identifier + )); + self.inner + .resolve_tag_version_with_request(tag, request) + .await } async fn create(&mut self, tag: &str, version: u64) -> Result<()> { @@ -215,6 +240,7 @@ pub struct RemoteTable { version: RwLock>, location: RwLock>, schema_cache: BackgroundCache, + freshness: Mutex, } impl std::fmt::Debug for RemoteTable { @@ -243,6 +269,7 @@ impl RemoteTable { version: RwLock::new(None), location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), + freshness: Mutex::new(FreshnessState::default()), } } @@ -252,12 +279,56 @@ impl RemoteTable { } async fn describe_version(&self, version: Option) -> Result { - let mut request = self - .client - .post(&format!("/v1/table/{}/describe/", self.identifier)); + let request = self.post_read(&format!("/v1/table/{}/describe/", self.identifier)); + self.describe_with_request(request, version).await + } + async fn resolve_tag_version_with_request( + &self, + tag: &str, + request: RequestBuilder, + ) -> Result { + let request = request.json(&serde_json::json!({ "tag": tag })); + + let (request_id, response) = self.send(request, true).await?; + let response = self.check_table_response(&request_id, response).await?; + + match response.text().await { + Ok(body) => { + let value: serde_json::Value = + serde_json::from_str(&body).map_err(|e| Error::Http { + source: format!("Failed to parse tag version: {}", e).into(), + request_id: request_id.clone(), + status_code: None, + })?; + + value + .get("version") + .and_then(|v| v.as_u64()) + .ok_or_else(|| Error::Http { + source: format!("Invalid tag version response: {}", body).into(), + request_id, + status_code: None, + }) + } + Err(err) => { + let status_code = err.status(); + Err(Error::Http { + source: Box::new(err), + request_id, + status_code, + }) + } + } + } + + async fn describe_with_request( + &self, + request: RequestBuilder, + version: Option, + ) -> Result { let body = serde_json::json!({ "version": version }); - request = request.json(&body); + let request = request.json(&body); let (request_id, response) = self.send(request, true).await?; @@ -711,14 +782,44 @@ impl RemoteTable { *read_guard } + /// Snapshot the freshness headers to attach to a single read request. + /// Computed at call time so that retries reuse the same snapshot. + fn snapshot_freshness_headers(&self) -> FreshnessHeaders { + let state = *self.freshness.lock().unwrap(); + FreshnessHeaders { + min_version: state.min_version, + min_timestamp: compute_min_timestamp( + &state, + self.client.read_consistency_interval, + SystemTime::now(), + ), + } + } + + /// Build a POST request and attach the read-freshness headers + /// (`x-lancedb-min-version`, `x-lancedb-min-timestamp`). + fn post_read(&self, uri: &str) -> RequestBuilder { + self.snapshot_freshness_headers() + .apply(self.client.post(uri)) + } + + /// Record a version returned by a write so subsequent reads can request at + /// least that version via `x-lancedb-min-version`. A returned `0` from a + /// backward-compatible old server is ignored. + fn track_write_version(&self, version: u64) { + if version == 0 { + return; + } + let mut state = self.freshness.lock().unwrap(); + state.min_version = Some(state.min_version.map_or(version, |v| v.max(version))); + } + async fn execute_query( &self, query: &AnyQuery, options: &QueryExecutionOptions, ) -> Result>>> { - let mut request = self - .client - .post(&format!("/v1/table/{}/query/", self.identifier)); + let mut request = self.post_read(&format!("/v1/table/{}/query/", self.identifier)); if let Some(timeout) = options.timeout { // Also send to server, so it can abort the query if it takes too long. @@ -824,9 +925,10 @@ async fn fetch_schema( identifier: &str, table_name: &str, version: Option, + freshness_headers: FreshnessHeaders, ) -> Result { - let request = client - .post(&format!("/v1/table/{}/describe/", identifier)) + let request = freshness_headers + .apply(client.post(&format!("/v1/table/{}/describe/", identifier))) .json(&serde_json::json!({ "version": version })); let (request_id, response) = client.send_with_retry(request, None, true).await?; @@ -874,7 +976,9 @@ mod test_utils { use super::*; use crate::remote::ClientConfig; use crate::remote::client::test_utils::client_with_handler; - use crate::remote::client::test_utils::{MockSender, client_with_handler_and_config}; + use crate::remote::client::test_utils::{ + MockSender, client_with_handler_and_config, client_with_handler_and_interval, + }; impl RemoteTable { pub fn new_mock(name: String, handler: F, version: Option) -> Self @@ -892,6 +996,30 @@ mod test_utils { version: RwLock::new(None), location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), + freshness: Mutex::new(FreshnessState::default()), + } + } + + pub fn new_mock_with_consistency_interval( + name: String, + handler: F, + read_consistency_interval: Option, + ) -> Self + where + F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, + T: Into, + { + let client = client_with_handler_and_interval(handler, read_consistency_interval); + Self { + client, + name: name.clone(), + namespace: vec![], + identifier: name, + server_version: ServerVersion::default(), + version: RwLock::new(None), + location: RwLock::new(None), + schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), + freshness: Mutex::new(FreshnessState::default()), } } @@ -923,6 +1051,7 @@ mod test_utils { version: RwLock::new(None), location: RwLock::new(None), schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), + freshness: Mutex::new(FreshnessState::default()), } } } @@ -986,6 +1115,7 @@ impl RemoteTable { if output.overwrite { self.invalidate_schema_cache(); } + self.track_write_version(add_result.version); return Ok(add_result); } @@ -1023,6 +1153,7 @@ impl RemoteTable { if output.overwrite { self.invalidate_schema_cache(); } + self.track_write_version(result.version); return Ok(result); } Err(e) => { @@ -1139,8 +1270,13 @@ impl BaseTable for RemoteTable { self.describe().await.map(|desc| desc.version) } async fn checkout(&self, version: u64) -> Result<()> { - // check that the version exists - self.describe_version(Some(version)) + // Validate the version exists. The describe is sent without freshness + // headers so a stale `min_version` from a previous write doesn't ride + // along on an explicit time-travel request. + let request = self + .client + .post(&format!("/v1/table/{}/describe/", self.identifier)); + self.describe_with_request(request, Some(version)) .await .map_err(|e| match e { // try to map the error to a more user-friendly error telling them @@ -1156,6 +1292,10 @@ impl BaseTable for RemoteTable { *write_guard = Some(version); drop(write_guard); + // Explicit time-travel: drop any read-your-write / freshness + // constraints so the user sees exactly the requested version. + *self.freshness.lock().unwrap() = FreshnessState::default(); + // Invalidate schema cache since we're switching versions self.invalidate_schema_cache(); @@ -1166,6 +1306,13 @@ impl BaseTable for RemoteTable { *write_guard = None; drop(write_guard); + // Drop any per-handle write tracking; subsequent reads use the + // baseline timestamp captured now to guarantee freshness. + *self.freshness.lock().unwrap() = FreshnessState { + min_version: None, + checkout_baseline: Some(SystemTime::now()), + }; + // Invalidate schema cache since we're switching versions self.invalidate_schema_cache(); @@ -1186,9 +1333,7 @@ impl BaseTable for RemoteTable { } async fn list_versions(&self) -> Result> { - let request = self - .client - .post(&format!("/v1/table/{}/version/list/", self.identifier)); + let request = self.post_read(&format!("/v1/table/{}/version/list/", self.identifier)); let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -1221,19 +1366,25 @@ impl BaseTable for RemoteTable { let client = self.client.clone(); let identifier = self.identifier.clone(); let table_name = self.name.clone(); + let freshness_headers = self.snapshot_freshness_headers(); self.schema_cache .get(move || async move { - fetch_schema(&client, &identifier, &table_name, version).await + fetch_schema( + &client, + &identifier, + &table_name, + version, + freshness_headers, + ) + .await }) .await .map_err(unwrap_shared_error) } async fn count_rows(&self, filter: Option) -> Result { - let mut request = self - .client - .post(&format!("/v1/table/{}/count_rows/", self.identifier)); + let mut request = self.post_read(&format!("/v1/table/{}/count_rows/", self.identifier)); let version = self.current_version().await; @@ -1359,9 +1510,7 @@ impl BaseTable for RemoteTable { } async fn explain_plan(&self, query: &AnyQuery, verbose: bool) -> Result { - let base_request = self - .client - .post(&format!("/v1/table/{}/explain_plan/", self.identifier)); + let base_request = self.post_read(&format!("/v1/table/{}/explain_plan/", self.identifier)); let query_bodies = self.prepare_query_bodies(query).await?; let requests: Vec = query_bodies @@ -1408,9 +1557,7 @@ impl BaseTable for RemoteTable { query: &AnyQuery, _options: QueryExecutionOptions, ) -> Result { - let request = self - .client - .post(&format!("/v1/table/{}/analyze_plan/", self.identifier)); + let request = self.post_read(&format!("/v1/table/{}/analyze_plan/", self.identifier)); let query_bodies = self.prepare_query_bodies(query).await?; let requests: Vec = query_bodies @@ -1480,12 +1627,17 @@ impl BaseTable for RemoteTable { status_code: None, })?; + self.track_write_version(update_response.version); Ok(update_response) } - async fn delete(&self, predicate: &str) -> Result { + async fn delete(&self, predicate: Predicate<'_>) -> Result { self.check_mutable().await?; - let body = serde_json::json!({ "predicate": predicate }); + let predicate_sql = match predicate { + Predicate::String(s) => s.to_string(), + Predicate::Expr(expr) => expr_to_sql_string(expr)?, + }; + let body = serde_json::json!({ "predicate": predicate_sql }); let request = self .client .post(&format!("/v1/table/{}/delete/", self.identifier)) @@ -1506,6 +1658,7 @@ impl BaseTable for RemoteTable { request_id, status_code: None, })?; + self.track_write_version(delete_response.version); Ok(delete_response) } @@ -1652,6 +1805,7 @@ impl BaseTable for RemoteTable { num_inserted_rows: 0, num_updated_rows: 0, num_attempts: 0, + num_rows: 0, }); } @@ -1662,6 +1816,7 @@ impl BaseTable for RemoteTable { status_code: None, })?; + self.track_write_version(merge_insert_response.version); Ok(merge_insert_response) } @@ -1687,12 +1842,22 @@ impl BaseTable for RemoteTable { Ok(Box::new(RemoteTags { inner: self })) } async fn checkout_tag(&self, tag: &str) -> Result<()> { - let tags = self.tags().await?; - let version = tags.get_version(tag).await?; + // Resolve the tag without attaching freshness headers; a stale + // `min_version` from a previous write should not ride along on an + // explicit time-travel request. + let request = self + .client + .post(&format!("/v1/table/{}/tags/version/", self.identifier)); + let version = self.resolve_tag_version_with_request(tag, request).await?; + let mut write_guard = self.version.write().await; *write_guard = Some(version); drop(write_guard); + // Explicit time-travel: drop any read-your-write / freshness + // constraints so the user sees exactly the tagged version. + *self.freshness.lock().unwrap() = FreshnessState::default(); + // Invalidate schema cache since we're switching versions self.invalidate_schema_cache(); @@ -1743,6 +1908,7 @@ impl BaseTable for RemoteTable { })?; self.invalidate_schema_cache(); + self.track_write_version(result.version); Ok(result) } @@ -1797,6 +1963,7 @@ impl BaseTable for RemoteTable { })?; self.invalidate_schema_cache(); + self.track_write_version(result.version); Ok(result) } @@ -1824,15 +1991,14 @@ impl BaseTable for RemoteTable { })?; self.invalidate_schema_cache(); + self.track_write_version(result.version); Ok(result) } async fn list_indices(&self) -> Result> { // Make request to list the indices - let mut request = self - .client - .post(&format!("/v1/table/{}/index/list/", self.identifier)); + let mut request = self.post_read(&format!("/v1/table/{}/index/list/", self.identifier)); let version = self.current_version().await; let body = serde_json::json!({ "version": version }); request = request.json(&body); @@ -1896,7 +2062,7 @@ impl BaseTable for RemoteTable { } async fn index_stats(&self, index_name: &str) -> Result> { - let mut request = self.client.post(&format!( + let mut request = self.post_read(&format!( "/v1/table/{}/index/{}/stats/", self.identifier, index_name )); @@ -2008,9 +2174,7 @@ impl BaseTable for RemoteTable { } async fn stats(&self) -> Result { - let request = self - .client - .post(&format!("/v1/table/{}/stats/", self.identifier)); + let request = self.post_read(&format!("/v1/table/{}/stats/", self.identifier)); let (request_id, response) = self.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; let body = response.text().await.err_to_http(request_id.clone())?; @@ -2851,6 +3015,33 @@ mod tests { assert_eq!(result.version, if old_server { 0 } else { 43 }); } + #[tokio::test] + async fn test_delete_expr() { + use datafusion_expr::{col, lit}; + + let table = Table::new_with_handler("my_table", move |request| { + if request.url().path() == "/v1/table/my_table/delete/" { + assert_eq!(request.method(), "POST"); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + assert!(body.get("predicate").unwrap().is_string()); + + http::Response::builder() + .status(200) + .body(r#"{"num_deleted_rows": 4, "version": 2}"#) + .unwrap() + } else { + panic!("Unexpected request path: {}", request.url().path()); + } + }); + + let expr = col("id").gt(lit(5)); + let result = table.delete(&expr).await.unwrap(); + assert_eq!(result.num_deleted_rows, 4); + assert_eq!(result.version, 2); + } + #[rstest] #[case(true)] #[case(false)] @@ -5974,4 +6165,299 @@ mod tests { assert_eq!(create_count.load(Ordering::SeqCst), 2); assert_eq!(abort_count.load(Ordering::SeqCst), 1); } + + // ---- Read freshness header tests ------------------------------------ + + #[test] + fn test_compute_min_timestamp_combines_baseline_and_interval() { + let now = SystemTime::now(); + let baseline = now - Duration::from_secs(60); + + // No interval, no baseline -> no header. + assert_eq!( + compute_min_timestamp(&FreshnessState::default(), None, now), + None + ); + + // Baseline only -> baseline. + let state = FreshnessState { + min_version: None, + checkout_baseline: Some(baseline), + }; + assert_eq!(compute_min_timestamp(&state, None, now), Some(baseline)); + + // ZERO interval, no baseline -> now. + assert_eq!( + compute_min_timestamp(&FreshnessState::default(), Some(Duration::ZERO), now), + Some(now) + ); + + // Positive interval, no baseline -> now - interval. + assert_eq!( + compute_min_timestamp( + &FreshnessState::default(), + Some(Duration::from_secs(10)), + now + ), + Some(now - Duration::from_secs(10)) + ); + + // Both: pick the more-recent (i.e. tighter) constraint. + // baseline = now-60, now-interval = now-10. now-10 is newer. + let state = FreshnessState { + min_version: None, + checkout_baseline: Some(baseline), + }; + assert_eq!( + compute_min_timestamp(&state, Some(Duration::from_secs(10)), now), + Some(now - Duration::from_secs(10)) + ); + + // Both, baseline newer: pick baseline. + let recent_baseline = now - Duration::from_secs(5); + let state = FreshnessState { + min_version: None, + checkout_baseline: Some(recent_baseline), + }; + assert_eq!( + compute_min_timestamp(&state, Some(Duration::from_secs(60)), now), + Some(recent_baseline) + ); + } + + /// Allowed slop when comparing a header timestamp against a locally + /// captured wall-clock bound. Tests run fast enough that 1s is plenty. + const FRESHNESS_TOLERANCE: Duration = Duration::from_secs(1); + + fn capturing_handler( + body_for: F, + ) -> ( + impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + Arc>>, + ) + where + F: Fn(&str) -> String + Clone + Send + Sync + 'static, + { + let captured = Arc::new(std::sync::Mutex::new(None)); + let captured_c = captured.clone(); + let handler = move |request: reqwest::Request| { + *captured_c.lock().unwrap() = Some(request.headers().clone()); + let path = request.url().path().to_string(); + http::Response::builder() + .status(200) + .body(body_for(&path)) + .unwrap() + }; + (handler, captured) + } + + fn parse_min_timestamp(headers: &http::HeaderMap) -> SystemTime { + let value = headers + .get("x-lancedb-min-timestamp") + .expect("expected x-lancedb-min-timestamp header") + .to_str() + .unwrap(); + chrono::DateTime::parse_from_rfc3339(value) + .unwrap() + .with_timezone(&chrono::Utc) + .into() + } + + #[tokio::test] + async fn test_freshness_default_sends_no_headers() { + let (handler, captured) = capturing_handler(|_| "42".to_string()); + let table = Table::new_with_handler("my_table", handler); + + let _ = table.count_rows(None).await.unwrap(); + + let headers = captured.lock().unwrap().clone().unwrap(); + assert!(!headers.contains_key("x-lancedb-min-timestamp")); + assert!(!headers.contains_key("x-lancedb-min-version")); + } + + #[tokio::test] + async fn test_freshness_zero_interval_sends_now() { + let (handler, captured) = capturing_handler(|_| "42".to_string()); + let table = + Table::new_with_handler_and_interval("my_table", handler, Some(Duration::from_secs(0))); + + let before = SystemTime::now(); + table.count_rows(None).await.unwrap(); + let after = SystemTime::now(); + + let headers = captured.lock().unwrap().clone().unwrap(); + let sent = parse_min_timestamp(&headers); + assert!( + sent >= before - FRESHNESS_TOLERANCE && sent <= after + FRESHNESS_TOLERANCE, + "expected timestamp roughly equal to wall clock" + ); + assert!(!headers.contains_key("x-lancedb-min-version")); + } + + #[tokio::test] + async fn test_freshness_positive_interval_sends_now_minus_interval() { + let (handler, captured) = capturing_handler(|_| "42".to_string()); + let interval = Duration::from_secs(30); + let table = Table::new_with_handler_and_interval("my_table", handler, Some(interval)); + + let before = SystemTime::now(); + table.count_rows(None).await.unwrap(); + let after = SystemTime::now(); + + let headers = captured.lock().unwrap().clone().unwrap(); + let sent = parse_min_timestamp(&headers); + assert!( + sent >= before - interval - FRESHNESS_TOLERANCE + && sent <= after - interval + FRESHNESS_TOLERANCE, + "expected timestamp roughly equal to now - interval" + ); + } + + #[tokio::test] + async fn test_freshness_checkout_latest_sets_baseline() { + let (handler, captured) = capturing_handler(|path| match path { + "/v1/table/my_table/count_rows/" => "42".to_string(), + _ => panic!("unexpected path: {}", path), + }); + // No interval — only the baseline should drive the timestamp. + let table = Table::new_with_handler_and_interval("my_table", handler, None); + + let before_checkout = SystemTime::now(); + table.checkout_latest().await.unwrap(); + let after_checkout = SystemTime::now(); + + table.count_rows(None).await.unwrap(); + + let headers = captured.lock().unwrap().clone().unwrap(); + let sent = parse_min_timestamp(&headers); + assert!( + sent >= before_checkout - FRESHNESS_TOLERANCE + && sent <= after_checkout + FRESHNESS_TOLERANCE, + "expected timestamp captured at checkout_latest() time" + ); + assert!(!headers.contains_key("x-lancedb-min-version")); + } + + #[tokio::test] + async fn test_freshness_min_version_tracked_after_write() { + let (handler, captured) = capturing_handler(|path| match path { + "/v1/table/my_table/update/" => r#"{"rows_updated":1,"version":7}"#.to_string(), + "/v1/table/my_table/count_rows/" => "42".to_string(), + _ => panic!("unexpected path: {}", path), + }); + let table = Table::new_with_handler("my_table", handler); + + let _ = table.update().column("a", "a + 1").execute().await.unwrap(); + // Update headers also pass through captured; reset by reading after. + table.count_rows(None).await.unwrap(); + + let headers = captured.lock().unwrap().clone().unwrap(); + assert_eq!( + headers + .get("x-lancedb-min-version") + .unwrap() + .to_str() + .unwrap(), + "7" + ); + } + + /// Like `capturing_handler`, but keeps a per-path snapshot of the headers + /// from every request so tests can assert on a specific endpoint. + #[allow(clippy::type_complexity)] + fn path_capturing_handler( + body_for: F, + ) -> ( + impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + Arc>>, + ) + where + F: Fn(&str) -> String + Clone + Send + Sync + 'static, + { + let captured: Arc>> = + Arc::new(std::sync::Mutex::new(HashMap::new())); + let captured_c = captured.clone(); + let handler = move |request: reqwest::Request| { + let path = request.url().path().to_string(); + captured_c + .lock() + .unwrap() + .insert(path.clone(), request.headers().clone()); + http::Response::builder() + .status(200) + .body(body_for(&path)) + .unwrap() + }; + (handler, captured) + } + + #[tokio::test] + async fn test_freshness_checkout_validation_sends_no_min_version() { + // After a write bumps min_version, calling checkout(v) must not let + // that stale header ride along on the validating /describe/ request. + let (handler, captured) = path_capturing_handler(|path| match path { + "/v1/table/my_table/update/" => r#"{"rows_updated":1,"version":7}"#.to_string(), + "/v1/table/my_table/describe/" => r#"{"version":5,"schema":{"fields":[]}}"#.to_string(), + _ => panic!("unexpected path: {}", path), + }); + let table = Table::new_with_handler("my_table", handler); + + table.update().column("a", "a + 1").execute().await.unwrap(); + table.checkout(5).await.unwrap(); + + let captured = captured.lock().unwrap(); + let describe_headers = captured + .get("/v1/table/my_table/describe/") + .expect("describe should have been called by checkout(v)"); + assert!( + !describe_headers.contains_key("x-lancedb-min-version"), + "checkout(v) describe must not carry stale min_version", + ); + assert!(!describe_headers.contains_key("x-lancedb-min-timestamp")); + } + + #[tokio::test] + async fn test_freshness_checkout_tag_resolve_sends_no_min_version() { + // Same invariant for checkout_tag: the tag-resolve request must not + // pick up a stale min_version from a prior write. + let (handler, captured) = path_capturing_handler(|path| match path { + "/v1/table/my_table/update/" => r#"{"rows_updated":1,"version":7}"#.to_string(), + "/v1/table/my_table/tags/version/" => r#"{"version":5}"#.to_string(), + _ => panic!("unexpected path: {}", path), + }); + let table = Table::new_with_handler("my_table", handler); + + table.update().column("a", "a + 1").execute().await.unwrap(); + table.checkout_tag("v_initial").await.unwrap(); + + let captured = captured.lock().unwrap(); + let resolve_headers = captured + .get("/v1/table/my_table/tags/version/") + .expect("tags/version should have been called by checkout_tag"); + assert!( + !resolve_headers.contains_key("x-lancedb-min-version"), + "checkout_tag resolve must not carry stale min_version", + ); + assert!(!resolve_headers.contains_key("x-lancedb-min-timestamp")); + } + + #[tokio::test] + async fn test_freshness_checkout_clears_min_version() { + let (handler, captured) = capturing_handler(|path| match path { + "/v1/table/my_table/update/" => r#"{"rows_updated":1,"version":7}"#.to_string(), + // checkout(5) needs to describe version 5 first + "/v1/table/my_table/describe/" => r#"{"version":5,"schema":{"fields":[]}}"#.to_string(), + "/v1/table/my_table/count_rows/" => "42".to_string(), + _ => panic!("unexpected path: {}", path), + }); + let table = Table::new_with_handler("my_table", handler); + + table.update().column("a", "a + 1").execute().await.unwrap(); + table.checkout(5).await.unwrap(); + table.count_rows(None).await.unwrap(); + + let headers = captured.lock().unwrap().clone().unwrap(); + assert!(!headers.contains_key("x-lancedb-min-version")); + assert!(!headers.contains_key("x-lancedb-min-timestamp")); + } } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 03f967e6e..ca34bbdf3 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -89,7 +89,6 @@ use futures::future::join_all; pub use lance::dataset::refs::{TagContents, Tags as LanceTags}; pub use lance::dataset::scanner::DatasetRecordBatchStream; use lance::dataset::statistics::DatasetStatisticsExt; -use lance_index::frag_reuse::FRAG_REUSE_INDEX_NAME; pub use lance_index::optimize::OptimizeOptions; pub use optimize::{CompactionOptions, OptimizeAction, OptimizeStats}; pub use schema_evolution::{AddColumnsResult, AlterColumnsResult, DropColumnsResult}; @@ -253,6 +252,36 @@ pub enum Filter { Datafusion(Expr), } +/// A predicate for filtering rows in delete operations. +/// +/// Accepts either a SQL string or a DataFusion [`Expr`]. Use the [`From`] +/// implementations to convert from `&str` or `&Expr` automatically. +/// See [`Table::delete`] for usage examples. +pub enum Predicate<'a> { + /// A SQL predicate string + String(&'a str), + /// A DataFusion logical expression + Expr(&'a Expr), +} + +impl<'a> From<&'a str> for Predicate<'a> { + fn from(s: &'a str) -> Self { + Predicate::String(s) + } +} + +impl<'a> From<&'a String> for Predicate<'a> { + fn from(s: &'a String) -> Self { + Predicate::String(s.as_str()) + } +} + +impl<'a> From<&'a Expr> for Predicate<'a> { + fn from(e: &'a Expr) -> Self { + Predicate::Expr(e) + } +} + #[async_trait] pub trait Tags: Send + Sync { /// List the tags of the table. @@ -282,17 +311,15 @@ pub use self::merge::MergeResult; /// date) and [`LsmWriteSpec::with_writer_config_defaults`] (default /// `ShardWriter` configuration recorded in the MemWAL index). /// -/// All variants require the table to have an unenforced primary key. -/// /// Install a spec with [`Table::set_lsm_write_spec`] and remove it with /// [`Table::unset_lsm_write_spec`]. The actual `merge_insert` dispatch /// onto the MemWAL writer is a follow-up. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum LsmWriteSpec { - /// Hash-bucket sharding by the unenforced primary key column. + /// Hash-bucket sharding by a scalar column. /// - /// `column` must equal the table's currently-set single-column - /// unenforced primary key. `num_buckets` must be in `[1, 1024]`. + /// `column` must be a non-nested column with a supported scalar type. + /// `num_buckets` must be in `[1, 1024]`. /// Iceberg-compatible Murmur3-x86-32 (seed 0) is used so each row's /// `bucket(column, num_buckets)` value is stable across processes. Bucket { @@ -339,6 +366,14 @@ impl LsmWriteSpec { /// Construct an identity-sharding spec (shard by the raw value of /// `column`) with no maintained indexes. + /// + /// `column` must be a deterministic function of the unenforced primary + /// key: every row with a given primary key must always produce the same + /// `column` value. MemWAL dedups upserts by primary key but tracks + /// generations per shard, so if the same key is written with two + /// different `column` values its versions land in different shards and a + /// stale value can win. Typically `column` is the primary key itself, or + /// a stable attribute of it (e.g. a tenant id). pub fn identity(column: impl Into) -> Self { Self::Identity { column: column.into(), @@ -491,8 +526,8 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { /// Add new records to the table. async fn add(&self, add: AddDataBuilder) -> Result; - /// Delete rows from the table. - async fn delete(&self, predicate: &str) -> Result; + /// Delete rows from the table matching the given [`Predicate`]. + async fn delete(&self, predicate: Predicate<'_>) -> Result; /// Update rows in the table. async fn update(&self, update: UpdateBuilder) -> Result; /// Create an index on the provided column(s). @@ -553,6 +588,13 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { message: "unset_lsm_write_spec is not supported on this table type".into(), }) } + /// Drain and close any cached MemWAL shard writers for this table. + /// + /// The default implementation is a no-op; table types that maintain + /// MemWAL shard writers override it. + async fn close_lsm_writers(&self) -> Result<()> { + Ok(()) + } /// Gets the table tag manager. async fn tags(&self) -> Result>; /// Optimize the dataset. @@ -656,6 +698,30 @@ mod test_utils { } } + pub fn new_with_handler_and_interval( + name: impl Into, + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + read_consistency_interval: Option, + ) -> Self + where + T: Into, + { + let inner = Arc::new( + crate::remote::table::RemoteTable::new_mock_with_consistency_interval( + name.into(), + handler.clone(), + read_consistency_interval, + ), + ); + let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler)); + Self { + inner, + database: Some(database), + // Registry is unused. + embedding_registry: Arc::new(MemoryRegistry::new()), + } + } + pub fn new_with_handler_version( name: impl Into, version: semver::Version, @@ -860,7 +926,8 @@ impl Table { /// Delete the rows from table that match the predicate. /// /// # Arguments - /// - `predicate` - The SQL predicate string to filter the rows to be deleted. + /// - `predicate` - A SQL string (`&str`) or DataFusion expression (`&Expr`) + /// that selects the rows to delete. /// /// # Example /// @@ -869,6 +936,7 @@ impl Table { /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch, /// # RecordBatchIterator, Int32Array}; /// # use arrow_schema::{Schema, Field, DataType}; + /// use datafusion_expr::{col, lit}; /// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// let tmpdir = tempfile::tempdir().unwrap(); /// let db = lancedb::connect(tmpdir.path().to_str().unwrap()) @@ -898,11 +966,17 @@ impl Table { /// .execute() /// .await /// .unwrap(); + /// + /// // Using a SQL string: /// tbl.delete("id > 5").await.unwrap(); + /// + /// // Using a DataFusion expression: + /// let expr = col("id").lt(lit(4)); + /// tbl.delete(&expr).await.unwrap(); /// # }); /// ``` - pub async fn delete(&self, predicate: &str) -> Result { - self.inner.delete(predicate).await + pub async fn delete(&self, predicate: impl Into>) -> Result { + self.inner.delete(predicate.into()).await } /// Create an index on the provided column(s). @@ -1298,21 +1372,15 @@ impl Table { /// /// [`LsmWriteSpec`] chooses one of three sharding strategies: /// - /// - [`LsmWriteSpec::bucket`] — hash-bucket writes by the single-column - /// unenforced primary key. + /// - [`LsmWriteSpec::bucket`] — hash-bucket writes by a scalar column. /// - [`LsmWriteSpec::identity`] — shard by the raw value of a scalar column. /// - [`LsmWriteSpec::unsharded`] — route every write to a single shard. /// - /// All variants require the table to have an unenforced primary key - /// ([`Table::set_unenforced_primary_key`]); bucket sharding additionally - /// requires it to be the single column being bucketed. - /// /// # Example /// /// ``` /// # use lancedb::table::{LsmWriteSpec, Table}; /// # async fn example(table: &Table) -> Result<(), Box> { - /// table.set_unenforced_primary_key(["id"]).await?; /// table /// .set_lsm_write_spec( /// LsmWriteSpec::bucket("id", 16).with_maintained_indexes(["id_idx"]), @@ -1333,6 +1401,16 @@ impl Table { self.inner.unset_lsm_write_spec().await } + /// Drain and close any cached MemWAL shard writers held for this table. + /// + /// When an [`LsmWriteSpec`] is installed, `merge_insert` opens MemWAL shard + /// writers and caches them for reuse across calls. This closes them, + /// flushing pending data; writers reopen lazily on the next `merge_insert`. + /// It is a no-op when no writers are cached. + pub async fn close_lsm_writers(&self) -> Result<()> { + self.inner.close_lsm_writers().await + } + /// Retrieve the version of the table /// /// LanceDb supports versioning. Every operation that modifies the table increases @@ -2776,9 +2854,12 @@ impl BaseTable for NativeTable { merge::lsm::unset_lsm_write_spec(self).await } + async fn close_lsm_writers(&self) -> Result<()> { + merge::lsm::close_lsm_writers(self).await + } + /// Delete rows from the table - async fn delete(&self, predicate: &str) -> Result { - // Delegate to the submodule implementation + async fn delete(&self, predicate: Predicate<'_>) -> Result { delete::execute_delete(self, predicate).await } @@ -2811,71 +2892,32 @@ impl BaseTable for NativeTable { async fn list_indices(&self) -> Result> { let dataset = self.dataset.get().await?; - let indices = dataset.load_indices().await?; - let results = futures::stream::iter(indices.as_slice()) - .then(|idx| async { - // skip Lance internal indexes - if idx.name == FRAG_REUSE_INDEX_NAME { - return None; - } - - let stats = match dataset.index_statistics(idx.name.as_str()).await { - Ok(stats) => stats, - Err(e) => { - log::warn!( - "Failed to get statistics for index {} ({}): {}", - idx.name, - idx.uuid, - e - ); - return None; - } - }; - - let stats: serde_json::Value = match serde_json::from_str(&stats) { - Ok(stats) => stats, - Err(e) => { - log::warn!( - "Failed to deserialize index statistics for index {} ({}): {}", - idx.name, - idx.uuid, - e - ); - return None; - } - }; - - let Some(index_type) = stats.get("index_type").and_then(|v| v.as_str()) else { - log::warn!( - "Index statistics was missing 'index_type' field for index {} ({})", - idx.name, - idx.uuid - ); - return None; - }; - - let index_type: crate::index::IndexType = match index_type.parse() { + let indices = dataset + .describe_indices(None) + .await? + .into_iter() + .filter_map(|idx_desc| { + let index_type: crate::index::IndexType = match idx_desc.index_type().parse() { Ok(index_type) => index_type, Err(e) => { log::warn!( - "Failed to parse index type for index {} ({}): {}", - idx.name, - idx.uuid, + "Failed to parse index type for index {}: {}", + idx_desc.name(), e ); return None; } }; - let mut columns = Vec::with_capacity(idx.fields.len()); - for field_id in &idx.fields { - let field_path = match dataset.schema().field_path(*field_id) { + let field_ids = idx_desc.field_ids(); + let mut columns = Vec::with_capacity(field_ids.len()); + for field_id in field_ids { + let field_path = match dataset.schema().field_path(*field_id as i32) { Ok(field_path) => field_path, Err(e) => { log::warn!( - "Failed to resolve field path for index {} ({}) field id {}: {}", - idx.name, - idx.uuid, + "Failed to resolve field path for index {} field id {}: {}", + idx_desc.name(), field_id, e ); @@ -2885,17 +2927,14 @@ impl BaseTable for NativeTable { columns.push(field_path); } - let name = idx.name.clone(); Some(IndexConfig { + name: idx_desc.name().to_string(), index_type, columns, - name, }) }) - .collect::>() - .await; - - Ok(results.into_iter().flatten().collect()) + .collect(); + Ok(indices) } async fn uri(&self) -> Result { @@ -3005,11 +3044,12 @@ impl BaseTable for NativeTable { let p99 = *sorted_sizes.get(num_fragments * 99 / 100).unwrap_or(&0); let min = sorted_sizes.first().copied().unwrap_or(0); let max = sorted_sizes.last().copied().unwrap_or(0); - let mean = if num_fragments == 0 { - 0 - } else { - sorted_sizes.iter().copied().sum::() / num_fragments - }; + let mean = sorted_sizes + .iter() + .copied() + .sum::() + .checked_div(num_fragments) + .unwrap_or(0); let frag_stats = FragmentStatistics { num_fragments, @@ -4009,26 +4049,27 @@ mod tests { let index_configs = table.list_indices().await.unwrap(); assert_eq!(index_configs.len(), 5); + // list_indices returns indices in alphabetical order by name let mut configs_iter = index_configs.into_iter(); let index = configs_iter.next().unwrap(); assert_eq!(index.index_type, crate::index::IndexType::Bitmap); assert_eq!(index.columns, vec!["category".to_string()]); - let index = configs_iter.next().unwrap(); - assert_eq!(index.index_type, crate::index::IndexType::Bitmap); - assert_eq!(index.columns, vec!["is_active".to_string()]); - let index = configs_iter.next().unwrap(); assert_eq!(index.index_type, crate::index::IndexType::Bitmap); assert_eq!(index.columns, vec!["data".to_string()]); let index = configs_iter.next().unwrap(); assert_eq!(index.index_type, crate::index::IndexType::Bitmap); - assert_eq!(index.columns, vec!["large_data".to_string()]); + assert_eq!(index.columns, vec!["is_active".to_string()]); let index = configs_iter.next().unwrap(); assert_eq!(index.index_type, crate::index::IndexType::Bitmap); assert_eq!(index.columns, vec!["large_category".to_string()]); + + let index = configs_iter.next().unwrap(); + assert_eq!(index.index_type, crate::index::IndexType::Bitmap); + assert_eq!(index.columns, vec!["large_data".to_string()]); } #[tokio::test] @@ -4600,21 +4641,6 @@ mod tests { .unwrap(); let table = conn.create_table("t", reader).execute().await.unwrap(); - // Reject when no PK is set. - let err = table - .set_lsm_write_spec(LsmWriteSpec::bucket("id", 4)) - .await - .expect_err("should reject without PK"); - assert!(matches!(err, Error::Lance { .. }), "got {:?}", err); - - // Set PK, then a mismatched column on the spec must be rejected. - table.set_unenforced_primary_key(["id"]).await.unwrap(); - let err = table - .set_lsm_write_spec(LsmWriteSpec::bucket("name", 4)) - .await - .expect_err("should reject column != PK"); - assert!(matches!(err, Error::Lance { .. }), "got {:?}", err); - // Reject num_buckets out of range. for bad in [0u32, 1025] { let err = table @@ -4680,9 +4706,6 @@ mod tests { .unwrap(); let table = conn.create_table("t", reader).execute().await.unwrap(); - // Lance's MemWAL still requires *some* unenforced primary key on - // the dataset; Unsharded just skips the per-row hashing step. - table.set_unenforced_primary_key(["id"]).await.unwrap(); table .set_lsm_write_spec(LsmWriteSpec::unsharded()) .await @@ -4729,7 +4752,6 @@ mod tests { .unwrap(); let table = conn.create_table("t", reader).execute().await.unwrap(); - table.set_unenforced_primary_key(["id"]).await.unwrap(); table .set_lsm_write_spec( LsmWriteSpec::identity("region") @@ -4785,7 +4807,6 @@ mod tests { table.unset_lsm_write_spec().await.unwrap_err(); // Install a spec, then unset it. - table.set_unenforced_primary_key(["id"]).await.unwrap(); table .set_lsm_write_spec(LsmWriteSpec::bucket("id", 4)) .await diff --git a/rust/lancedb/src/table/add_data.rs b/rust/lancedb/src/table/add_data.rs index be8ec28ad..6c92e1b43 100644 --- a/rust/lancedb/src/table/add_data.rs +++ b/rust/lancedb/src/table/add_data.rs @@ -982,4 +982,105 @@ mod tests { table2.add(struct_batch).execute().await.unwrap(); assert_eq!(table2.count_rows(None).await.unwrap(), 2); } + + /// Regression test: appending `arrow.json` (PyArrow `pa.json_()`) data into a table + /// whose schema was created with `pa.json_()` (internally stored as `lance.json`, backed + /// by `LargeBinary`) must succeed without a schema-mismatch error. + /// + /// Previously `build_field_exprs` would attempt a `Utf8 → LargeBinary` DataFusion cast, + /// which produced a field whose Arrow extension metadata still read `arrow.json` instead + /// of `lance.json`. Lance-core then rejected the append with + /// `"json vs large_binary" schema mismatch`. + /// + /// PyArrow's `pa.json_()` may be backed by either `Utf8` or `LargeUtf8` depending on the + /// constructor used, so the test is parameterized over the input backing type. + #[rstest::rstest] + #[case::utf8(DataType::Utf8)] + #[case::large_utf8(DataType::LargeUtf8)] + #[tokio::test] + async fn test_add_arrow_json_into_lance_json_table(#[case] input_type: DataType) { + use arrow_array::{Array, cast::AsArray}; + use lance_arrow::ARROW_EXT_NAME_KEY; + use lance_arrow::json::{ARROW_JSON_EXT_NAME, JSON_EXT_NAME}; + + // Build a table whose "data" column is lance.json (LargeBinary + + // ARROW:extension:name = "lance.json"). + let lance_json_field = lance_arrow::json::json_field("data", true); + let table_schema = Arc::new(Schema::new(vec![lance_json_field])); + + let db = connect("memory://").execute().await.unwrap(); + let table = db + .create_empty_table("json_test", table_schema) + .execute() + .await + .unwrap(); + + // Sanity-check the stored schema. + let stored_field = table.schema().await.unwrap(); + let data_field = stored_field.field_with_name("data").unwrap(); + assert_eq!(data_field.data_type(), &DataType::LargeBinary); + assert_eq!( + data_field + .metadata() + .get(ARROW_EXT_NAME_KEY) + .map(|s| s.as_str()), + Some(JSON_EXT_NAME), + ); + + // Build an arrow.json input field (Utf8/LargeUtf8 + arrow.json extension). + // This is what PyArrow produces for pa.json_() arrays. + let arrow_json_metadata = std::collections::HashMap::from([( + ARROW_EXT_NAME_KEY.to_string(), + ARROW_JSON_EXT_NAME.to_string(), + )]); + let arrow_json_field = + Field::new("data", input_type.clone(), true).with_metadata(arrow_json_metadata); + let arrow_json_schema = Arc::new(Schema::new(vec![arrow_json_field])); + + let rows: Vec> = vec![None, Some(r#"{"a": 1}"#), Some(r#"{"b": 2}"#)]; + let string_array: Arc = match input_type { + DataType::Utf8 => Arc::new(arrow_array::StringArray::from(rows.clone())), + DataType::LargeUtf8 => Arc::new(arrow_array::LargeStringArray::from(rows.clone())), + other => panic!("unsupported arrow.json backing type for this test: {other:?}"), + }; + let batch = RecordBatch::try_new(arrow_json_schema, vec![string_array]).unwrap(); + + // This must not fail with a schema-mismatch error. + table.add(batch).execute().await.unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), rows.len()); + + // A lance.json column is read back as Utf8 carrying arrow.json extension metadata. + let results: Vec = table + .query() + .select(Select::columns(&["data"])) + .execute() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_rows(), rows.len()); + + let json_col = batch.column(0); + assert_eq!(json_col.data_type(), &DataType::Utf8); + let json_strs = json_col.as_string::(); + + for (i, expected) in rows.iter().enumerate() { + match expected { + None => assert!(json_strs.is_null(i), "row {i} expected null"), + Some(raw) => { + assert!(!json_strs.is_null(i), "row {i} expected non-null"); + let actual: serde_json::Value = serde_json::from_str(json_strs.value(i)) + .expect("read-back JSON should be valid"); + let expected: serde_json::Value = + serde_json::from_str(raw).expect("expected JSON should be valid"); + assert_eq!(actual, expected, "row {i} JSON mismatch"); + } + } + } + } } diff --git a/rust/lancedb/src/table/datafusion/cast.rs b/rust/lancedb/src/table/datafusion/cast.rs index b4abb16c5..ccf72ccb8 100644 --- a/rust/lancedb/src/table/datafusion/cast.rs +++ b/rust/lancedb/src/table/datafusion/cast.rs @@ -13,6 +13,7 @@ use datafusion_physical_expr::expressions::{CastExpr, Literal}; use datafusion_physical_plan::expressions::Column; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr}; +use lance_arrow::json::{is_arrow_json_field, is_json_field}; use crate::{Error, Result}; @@ -64,6 +65,18 @@ fn build_field_exprs( let input_field = &input_fields[input_idx]; let input_expr = get_input_expr(input_idx); + // Special case: input is arrow.json (PyArrow pa.json_() extension type backed by + // Utf8/LargeUtf8) and the table field is lance.json (backed by LargeBinary). + // Lance-core's write path already handles the arrow.json → lance.json conversion + // (including JSONB encoding), so we pass the expression through unchanged and let + // lance-core deal with it. Attempting to cast Utf8 → LargeBinary here would + // produce a field whose metadata still identifies it as arrow.json, which then + // causes a schema-mismatch error inside lance-core. + if is_arrow_json_field(input_field) && is_json_field(table_field) { + result.push((input_expr, Arc::clone(input_field) as FieldRef)); + continue; + } + let expr = match (input_field.data_type(), table_field.data_type()) { // Both are structs: recurse into sub-fields to handle subschemas and casts. (DataType::Struct(in_children), DataType::Struct(tbl_children)) @@ -618,4 +631,75 @@ mod tests { .unwrap(); assert_eq!(a.values(), &[1, 3]); } + + /// `arrow.json` input (PyArrow `pa.json_()`, Utf8/LargeUtf8 + extension metadata) against a + /// `lance.json` table field (LargeBinary + extension metadata) must be passed through + /// without a cast so that lance-core can perform its own arrow.json → JSONB conversion. + /// + /// Before the fix, `cast_to_table_schema` attempted a `Utf8 → LargeBinary` DataFusion + /// cast that preserved the wrong extension metadata, causing lance-core to reject the + /// batch with a "json vs large_binary" schema-mismatch error. + #[rstest::rstest] + #[case::utf8(DataType::Utf8)] + #[case::large_utf8(DataType::LargeUtf8)] + #[tokio::test] + async fn test_arrow_json_passthrough_to_lance_json(#[case] input_type: DataType) { + use lance_arrow::ARROW_EXT_NAME_KEY; + use lance_arrow::json::{ARROW_JSON_EXT_NAME, json_field}; + + // Build a table schema with a lance.json field (LargeBinary + lance.json metadata). + let lance_field = json_field("data", true); + let table_schema = Schema::new(vec![lance_field]); + + // Build an input batch with an arrow.json field (Utf8/LargeUtf8 + arrow.json metadata). + let arrow_meta = std::collections::HashMap::from([( + ARROW_EXT_NAME_KEY.to_string(), + ARROW_JSON_EXT_NAME.to_string(), + )]); + let arrow_field = Field::new("data", input_type.clone(), true).with_metadata(arrow_meta); + let input_schema = Arc::new(Schema::new(vec![arrow_field])); + + let values = vec![Some(r#"{"x": 1}"#), None, Some(r#"{"y": 2}"#)]; + let input_array: Arc = match input_type { + DataType::Utf8 => Arc::new(StringArray::from(values)), + DataType::LargeUtf8 => Arc::new(arrow_array::LargeStringArray::from(values)), + other => panic!("unsupported arrow.json backing type for this test: {other:?}"), + }; + let input_batch = RecordBatch::try_new(input_schema, vec![input_array]).unwrap(); + + let plan = plan_from_batch(input_batch).await; + let projected = cast_to_table_schema(plan, &table_schema).unwrap(); + + // The projected schema's "data" field must carry arrow.json metadata + // (the input field), not be silently dropped or miscast. + let out_field = projected.schema().field_with_name("data").unwrap().clone(); + assert_eq!(out_field.data_type(), &input_type); + assert_eq!( + out_field + .metadata() + .get(ARROW_EXT_NAME_KEY) + .map(|s| s.as_str()), + Some(ARROW_JSON_EXT_NAME), + "output field must still carry arrow.json metadata so lance-core can handle it" + ); + + // The data must flow through correctly (3 rows, no panic). + let result = collect(projected).await; + assert_eq!(result.num_rows(), 3); + let (v0, v2) = match input_type { + DataType::Utf8 => { + let col: &StringArray = result.column(0).as_any().downcast_ref().unwrap(); + (col.value(0).to_string(), col.value(2).to_string()) + } + DataType::LargeUtf8 => { + let col: &arrow_array::LargeStringArray = + result.column(0).as_any().downcast_ref().unwrap(); + (col.value(0).to_string(), col.value(2).to_string()) + } + _ => unreachable!(), + }; + assert_eq!(v0, r#"{"x": 1}"#); + assert!(result.column(0).is_null(1)); + assert_eq!(v2, r#"{"y": 2}"#); + } } diff --git a/rust/lancedb/src/table/datafusion/udtf/fts.rs b/rust/lancedb/src/table/datafusion/udtf/fts.rs index 5b50ddfa3..8b79ca676 100644 --- a/rust/lancedb/src/table/datafusion/udtf/fts.rs +++ b/rust/lancedb/src/table/datafusion/udtf/fts.rs @@ -870,8 +870,10 @@ mod tests { .await .unwrap(); - // Should return empty or nearly empty result - assert!(result[0].num_rows() <= 1); + assert_eq!( + result.iter().map(|batch| batch.num_rows()).sum::(), + 0 + ); } #[tokio::test] diff --git a/rust/lancedb/src/table/dataset.rs b/rust/lancedb/src/table/dataset.rs index 584d45a2f..b4673d876 100644 --- a/rust/lancedb/src/table/dataset.rs +++ b/rust/lancedb/src/table/dataset.rs @@ -8,6 +8,7 @@ use std::{ use lance::{Dataset, dataset::refs}; +use crate::table::merge::lsm::ShardWriterCache; use crate::{Error, error::Result, utils::background_cache::BackgroundCache}; /// A wrapper around a [Dataset] that provides consistency checks. @@ -18,6 +19,10 @@ use crate::{Error, error::Result, utils::background_cache::BackgroundCache}; pub struct DatasetConsistencyWrapper { state: Arc>, consistency: ConsistencyMode, + /// The single MemWAL `ShardWriter` for this dataset, co-located so it is + /// cached for the session and shares the dataset's lifecycle. A dataset + /// writes to one shard at a time. Shared by `Arc` across clones. + shard_writer: Arc, } /// The current dataset and whether it is pinned to a specific version. @@ -67,9 +72,15 @@ impl DatasetConsistencyWrapper { pinned_version: None, })), consistency, + shard_writer: Arc::new(ShardWriterCache::default()), } } + /// The MemWAL `ShardWriter` cache co-located with this dataset. + pub(crate) fn shard_writer(&self) -> &Arc { + &self.shard_writer + } + /// Get the current dataset. /// /// Behavior depends on the consistency mode: diff --git a/rust/lancedb/src/table/delete.rs b/rust/lancedb/src/table/delete.rs index 3d469393c..8f11ee019 100644 --- a/rust/lancedb/src/table/delete.rs +++ b/rust/lancedb/src/table/delete.rs @@ -1,9 +1,12 @@ +use std::sync::Arc; + use futures::FutureExt; +use lance::dataset::DeleteBuilder; // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use serde::{Deserialize, Serialize}; -use super::NativeTable; +use super::{NativeTable, Predicate}; use crate::Result; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] @@ -21,17 +24,39 @@ pub struct DeleteResult { /// Internal implementation of the delete logic /// /// This logic was moved from NativeTable::delete to keep table.rs clean. -pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result { +pub(crate) async fn execute_delete( + table: &NativeTable, + predicate: Predicate<'_>, +) -> Result { table.dataset.ensure_mutable()?; - let mut dataset = (*table.dataset.get().await?).clone(); - let delete_result = dataset.delete(predicate).boxed().await?; - let num_deleted_rows = delete_result.num_deleted_rows; - let version = dataset.version().version; - table.dataset.update(dataset); - Ok(DeleteResult { - num_deleted_rows, - version, - }) + match predicate { + Predicate::String(s) => { + let mut dataset = (*table.dataset.get().await?).clone(); + let delete_result = dataset.delete(s).boxed().await?; + let num_deleted_rows = delete_result.num_deleted_rows; + let version = dataset.version().version; + table.dataset.update(dataset); + Ok(DeleteResult { + num_deleted_rows, + version, + }) + } + Predicate::Expr(expr) => { + let dataset = table.dataset.get().await?; + let delete_result = DeleteBuilder::from_expr(Arc::clone(&dataset), expr.clone()) + .execute() + .await?; + let num_deleted_rows = delete_result.num_deleted_rows; + let version = delete_result.new_dataset.version().version; + table.dataset.update( + Arc::try_unwrap(delete_result.new_dataset).unwrap_or_else(|arc| (*arc).clone()), + ); + Ok(DeleteResult { + num_deleted_rows, + version, + }) + } + } } #[cfg(test)] @@ -176,4 +201,100 @@ mod tests { "Table version must increment after delete operation" ); } + + #[tokio::test] + async fn test_delete_expr() { + use datafusion_expr::{col, lit}; + + let conn = connect("memory://").execute().await.unwrap(); + + // 1. Create a table with values 0 to 9 + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..10))], + ) + .unwrap(); + + let table = conn + .create_table("test_delete_expr", batch) + .execute() + .await + .unwrap(); + + // 2. Verify initial state + assert_eq!(table.count_rows(None).await.unwrap(), 10); + let initial_version = table.version().await.unwrap(); + + // 3. Execute Delete with Expr (removes values > 5) + let expr = col("i").gt(lit(5)); + table.delete(&expr).await.unwrap(); + + // 4. Verify results + assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain + let current_version = table.version().await.unwrap(); + assert!( + current_version > initial_version, + "Table version must increment after delete_expr operation" + ); + + // 5. Verify specific data consistency + let batches = table + .query() + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let batch = &batches[0]; + let array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Ensure no value > 5 exists + for val in array.iter() { + assert!(val.unwrap() <= 5); + } + } + + #[tokio::test] + async fn test_delete_expr_increments_version() { + use datafusion_expr::lit; + + let conn = connect("memory://").execute().await.unwrap(); + + // Create a table with 5 rows + let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap(); + + let table = conn + .create_table("test_delete_expr_noop", batch) + .execute() + .await + .unwrap(); + + // Capture the initial state (Rows = 5, Version = 1) + let initial_rows = table.count_rows(None).await.unwrap(); + let initial_version = table.version().await.unwrap(); + + assert_eq!(initial_rows, 5); + let expr = lit(false); + table.delete(&expr).await.unwrap(); + + // Rows should still be 5 + let current_rows = table.count_rows(None).await.unwrap(); + assert_eq!( + current_rows, initial_rows, + "Data should not change when predicate is false" + ); + + // version check + let current_version = table.version().await.unwrap(); + assert!( + current_version > initial_version, + "Table version must increment after delete_expr operation" + ); + } } diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index def78aa4f..b3bda36af 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -41,6 +41,16 @@ pub struct MergeResult { /// A value of 1 means the operation succeeded on the first try. #[serde(default)] pub num_attempts: u32, + /// Total number of rows written. + /// + /// On the standard `merge_insert` path this equals + /// `num_inserted_rows + num_updated_rows`. On the MemWAL LSM write path the + /// insert/update breakdown is not known until compaction; in that mode + /// `num_inserted_rows`, `num_updated_rows`, `num_deleted_rows`, `version` + /// and `num_attempts` are all `0` and this field holds the total number of + /// rows written through the shard writer. + #[serde(default)] + pub num_rows: u64, } /// A builder used to create and run a merge insert operation @@ -57,6 +67,8 @@ pub struct MergeInsertBuilder { pub(crate) when_not_matched_by_source_delete_filt: Option, pub(crate) timeout: Option, pub(crate) use_index: bool, + pub(crate) use_lsm_write: Option, + pub(crate) validate_single_shard: bool, } impl MergeInsertBuilder { @@ -71,6 +83,8 @@ impl MergeInsertBuilder { when_not_matched_by_source_delete_filt: None, timeout: None, use_index: true, + use_lsm_write: None, + validate_single_shard: true, } } @@ -150,6 +164,34 @@ impl MergeInsertBuilder { self } + /// Controls whether `merge_insert` uses the MemWAL LSM write path. + /// + /// By default (unset), a `merge_insert` on a table with an + /// [`LsmWriteSpec`](super::LsmWriteSpec) installed is routed through + /// Lance's MemWAL shard writer, and a table without one uses the standard + /// path. Calling this with `false` forces the standard path even when a + /// spec is set. Calling it with `true` requires a spec — `merge_insert` + /// errors if none is installed. + pub fn use_lsm_write(&mut self, use_lsm_write: bool) -> &mut Self { + self.use_lsm_write = Some(use_lsm_write); + self + } + + /// Controls how an LSM `merge_insert` checks that its input targets a + /// single shard. + /// + /// When a table has an LSM write spec, every row in a `merge_insert` call + /// must route to the same shard. When `true` (the default), every row is + /// inspected to verify this. When `false`, only the first row is inspected + /// and the shard it routes to is used for the whole input — a faster path + /// for callers that have already pre-sharded their input. + /// + /// Has no effect on tables without an LSM write spec. + pub fn validate_single_shard(&mut self, validate_single_shard: bool) -> &mut Self { + self.validate_single_shard = validate_single_shard; + self + } + /// Executes the merge insert operation /// /// Returns version and statistics about the merge operation including the number of rows @@ -167,6 +209,23 @@ pub(crate) async fn execute_merge_insert( params: MergeInsertBuilder, new_data: Box, ) -> Result { + match lsm::lsm_dispatch_decision(table, ¶ms).await? { + lsm::LsmDispatch::Lsm(plan) => { + let future = + lsm::execute_lsm_merge_insert(table, plan, params.validate_single_shard, new_data); + return match params.timeout { + Some(timeout) => match tokio::time::timeout(timeout, future).await { + Ok(result) => result, + Err(_) => Err(Error::Runtime { + message: "merge insert timed out".to_string(), + }), + }, + None => future.await, + }; + } + lsm::LsmDispatch::Standard => {} + } + let dataset = table.dataset.get().await?; let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; match ( @@ -219,6 +278,7 @@ pub(crate) async fn execute_merge_insert( num_inserted_rows: stats.num_inserted_rows, num_deleted_rows: stats.num_deleted_rows, num_attempts: stats.num_attempts, + num_rows: stats.num_inserted_rows + stats.num_updated_rows, }) } @@ -327,3 +387,366 @@ mod tests { assert_eq!(table.count_rows(None).await.unwrap(), 25); } } + +#[cfg(test)] +mod lsm_tests { + use std::sync::Arc; + + use arrow_array::{ + Int64Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, + }; + use arrow_schema::{DataType, Field, Schema}; + use tempfile::{TempDir, tempdir}; + + use crate::connect; + use crate::error::Error; + use crate::table::{LsmWriteSpec, Table}; + + /// A reader of `[id: Int64, value: Int64]` rows; `value` is `0..n`. + fn id_value_reader(ids: Vec) -> Box { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Int64, false), + ])); + let n = ids.len() as i64; + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(Int64Array::from_iter_values(0..n)), + ], + ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)) + } + + /// A reader of `[id: Int64, region: Utf8]` rows. + fn id_region_reader(rows: Vec<(i64, &str)>) -> Box { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("region", DataType::Utf8, false), + ])); + let ids: Vec = rows.iter().map(|(id, _)| *id).collect(); + let regions: Vec<&str> = rows.iter().map(|(_, region)| *region).collect(); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(StringArray::from(regions)), + ], + ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)) + } + + /// A multi-batch reader of `[id: Int64, region: Utf8]` rows. + fn id_region_multi_reader(batches: Vec>) -> Box { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("region", DataType::Utf8, false), + ])); + let records: Vec<_> = batches + .into_iter() + .map(|rows| { + let ids: Vec = rows.iter().map(|(id, _)| *id).collect(); + let regions: Vec<&str> = rows.iter().map(|(_, region)| *region).collect(); + Ok(RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(StringArray::from(regions)), + ], + ) + .unwrap()) + }) + .collect(); + Box::new(RecordBatchIterator::new(records, schema)) + } + + /// Create an `[id, value]` table with `id` as the unenforced primary key. + async fn id_value_table(dir: &TempDir) -> Table { + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_value_reader(vec![1, 2, 3])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + } + + #[tokio::test] + async fn lsm_merge_insert_bucket() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + // num_buckets = 1: every row routes to the single bucket. + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + // Empty `on` defaults to the primary key. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_value_reader(vec![3, 4, 5])) + .await + .unwrap(); + + // LSM path: rows go to the MemWAL, the breakdown is unknown until + // compaction, so only `num_rows` is populated. + assert_eq!(result.num_rows, 3); + assert_eq!(result.version, 0); + assert_eq!(result.num_inserted_rows, 0); + assert_eq!(result.num_updated_rows, 0); + } + + #[tokio::test] + async fn lsm_merge_insert_unsharded() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::unsharded()) + .await + .unwrap(); + + let mut builder = table.merge_insert(&["id"]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_value_reader(vec![10, 11, 12, 13])) + .await + .unwrap(); + assert_eq!(result.num_rows, 4); + } + + #[tokio::test] + async fn lsm_merge_insert_identity() { + let dir = tempdir().unwrap(); + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_region_reader(vec![(1, "us"), (2, "us")])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + .set_lsm_write_spec(LsmWriteSpec::identity("region")) + .await + .unwrap(); + + // All rows share one identity value, so they route to one shard. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_region_reader(vec![(3, "us"), (4, "us")])) + .await + .unwrap(); + assert_eq!(result.num_rows, 2); + } + + #[tokio::test] + async fn lsm_merge_insert_use_lsm_write_false_falls_back() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + // use_lsm_write(false) opts out: the standard path runs and commits. + let mut builder = table.merge_insert(&["id"]); + builder.when_not_matched_insert_all().use_lsm_write(false); + let result = builder + .execute(id_value_reader(vec![3, 4, 5])) + .await + .unwrap(); + + assert_eq!(result.num_inserted_rows, 2); + assert_eq!(table.count_rows(None).await.unwrap(), 5); + } + + #[tokio::test] + async fn lsm_merge_insert_rejects_on_not_primary_key() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + let mut builder = table.merge_insert(&["value"]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let err = builder.execute(id_value_reader(vec![1])).await.unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_merge_insert_rejects_non_upsert() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + // Insert-only (no when_matched_update_all) is not the upsert shape. + let mut builder = table.merge_insert(&[]); + builder.when_not_matched_insert_all(); + let err = builder.execute(id_value_reader(vec![4])).await.unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_close_writers_then_reopen() { + let dir = tempdir().unwrap(); + let table = id_value_table(&dir).await; + table + .set_lsm_write_spec(LsmWriteSpec::bucket("id", 1)) + .await + .unwrap(); + + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + builder.execute(id_value_reader(vec![7, 8])).await.unwrap(); + + table.close_lsm_writers().await.unwrap(); + + // The writer reopens lazily on the next merge_insert. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder.execute(id_value_reader(vec![9])).await.unwrap(); + assert_eq!(result.num_rows, 1); + } + + #[tokio::test] + async fn lsm_merge_insert_multi_batch() { + let dir = tempdir().unwrap(); + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_region_reader(vec![(1, "us")])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + .set_lsm_write_spec(LsmWriteSpec::identity("region")) + .await + .unwrap(); + + // Multiple batches that all route to one shard are written together. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let result = builder + .execute(id_region_multi_reader(vec![ + vec![(2, "us"), (3, "us")], + vec![(4, "us")], + ])) + .await + .unwrap(); + assert_eq!(result.num_rows, 3); + + // Batches that route to different shards are rejected; the validation + // runs before any write, so no partial write is left behind. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let err = builder + .execute(id_region_multi_reader(vec![ + vec![(5, "us")], + vec![(6, "eu")], + ])) + .await + .unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_merge_insert_use_lsm_write_true_requires_spec() { + let dir = tempdir().unwrap(); + // id_value_table sets a primary key but no LSM write spec. + let table = id_value_table(&dir).await; + + let mut builder = table.merge_insert(&["id"]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all() + .use_lsm_write(true); + let err = builder.execute(id_value_reader(vec![4])).await.unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + } + + #[tokio::test] + async fn lsm_merge_insert_rejects_second_shard() { + let dir = tempdir().unwrap(); + let conn = connect(dir.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + let table = conn + .create_table("t", id_region_reader(vec![(1, "us")])) + .execute() + .await + .unwrap(); + table.set_unenforced_primary_key(["id"]).await.unwrap(); + table + .set_lsm_write_spec(LsmWriteSpec::identity("region")) + .await + .unwrap(); + + // The first merge_insert opens the single writer for shard "us". + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + builder + .execute(id_region_reader(vec![(2, "us")])) + .await + .unwrap(); + + // A merge_insert routing to a different shard is rejected. + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + let err = builder + .execute(id_region_reader(vec![(3, "eu")])) + .await + .unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}"); + + // After closing the writer, a different shard can be written. + table.close_lsm_writers().await.unwrap(); + let mut builder = table.merge_insert(&[]); + builder + .when_matched_update_all(None) + .when_not_matched_insert_all(); + builder + .execute(id_region_reader(vec![(4, "eu")])) + .await + .unwrap(); + } +} diff --git a/rust/lancedb/src/table/merge/lsm.rs b/rust/lancedb/src/table/merge/lsm.rs index 51d04f5e0..80246d59f 100644 --- a/rust/lancedb/src/table/merge/lsm.rs +++ b/rust/lancedb/src/table/merge/lsm.rs @@ -1,26 +1,71 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -//! MemWAL LSM write-path spec management. +//! MemWAL LSM write path for `merge_insert`. //! -//! [`set_lsm_write_spec`] installs a [`super::super::LsmWriteSpec`] on a -//! table, which selects Lance's MemWAL LSM-style write path for future -//! `merge_insert` calls. [`unset_lsm_write_spec`] removes it. The actual -//! `merge_insert` dispatch and writer are a follow-up. +//! [`set_lsm_write_spec`] installs an [`LsmWriteSpec`] on a table by creating +//! Lance's MemWAL index; [`unset_lsm_write_spec`] removes it. Once a spec is +//! installed, `merge_insert` upsert calls are dispatched through Lance's MemWAL +//! `ShardWriter` (LSM-style append) instead of the standard merge path — see +//! [`lsm_dispatch_decision`] and [`execute_lsm_merge_insert`]. +//! +//! Each `merge_insert` call must target a single shard: every row must route +//! to the same shard under the installed sharding spec (bucket / identity / +//! unsharded). [`MergeInsertBuilder::validate_single_shard`] controls whether +//! every row is checked or only the first. A dataset writes to one shard at a +//! time; its writer is cached in the [`ShardWriterCache`] held alongside the +//! dataset, and [`close_lsm_writers`] closes it. -use lance::dataset::mem_wal::DatasetMemWalExt; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use arrow_array::{Array, ArrayRef, Int32Array, RecordBatch, RecordBatchReader}; +use arrow_schema::{DataType, Schema as ArrowSchema, SchemaRef}; +use lance::Dataset; +use lance::dataset::mem_wal::{ + DatasetMemWalExt, ShardWriter, ShardWriterConfig, evaluate_sharding_spec, +}; use lance::index::DatasetIndexExt; +use lance_core::datatypes::Schema as LanceSchema; +use lance_index::mem_wal::{MemWalIndexDetails, ShardingSpec}; +use tokio::sync::RwLock; +use uuid::Uuid; use crate::error::{Error, Result}; +use crate::table::merge::{MergeInsertBuilder, MergeResult}; use crate::table::{LsmWriteSpec, NativeTable}; +/// Spec id of the sole sharding spec installed by [`set_lsm_write_spec`]. +/// Must match Lance's `InitializeMemWalBuilder` (`SHARDING_SPEC_ID`). +const SHARDING_SPEC_ID: u32 = 1; + +/// Transform name recorded by `bucket_sharding`. +const BUCKET_TRANSFORM: &str = "bucket"; +/// Transform name recorded by `identity_sharding`. +const IDENTITY_TRANSFORM: &str = "identity"; +/// Transform name recorded by `unsharded`. +const UNSHARDED_TRANSFORM: &str = "unsharded"; + +/// Parameter key holding the bucket count on the bucket transform. +const NUM_BUCKETS_PARAM: &str = "num_buckets"; + +/// Fixed namespace UUID for deriving deterministic shard ids. Hardcoded so +/// derivations stay stable across processes. +const SHARD_NAMESPACE: Uuid = Uuid::from_u128(0x4c53_4d57_5249_5445_5f53_4841_5244_3031); + // ============================================================================= // set_lsm_write_spec // ============================================================================= /// Install an [`LsmWriteSpec`] on the table. /// -/// The bucket / unsharded sharding spec is constructed and validated by Lance's +/// The bucket / identity / unsharded sharding spec is constructed and validated +/// by Lance's /// [`InitializeMemWalBuilder`](lance::dataset::mem_wal::InitializeMemWalBuilder). #[allow(clippy::redundant_pub_crate)] pub(crate) async fn set_lsm_write_spec(table: &NativeTable, spec: LsmWriteSpec) -> Result<()> { @@ -78,7 +123,8 @@ pub(crate) async fn set_lsm_write_spec(table: &NativeTable, spec: LsmWriteSpec) /// Remove the [`LsmWriteSpec`] from the table by dropping the MemWAL index. /// -/// Errors if no spec is currently set. +/// Any cached shard writers are drained and closed first. Errors if no spec is +/// currently set. #[allow(clippy::redundant_pub_crate)] pub(crate) async fn unset_lsm_write_spec(table: &NativeTable) -> Result<()> { table.dataset.ensure_mutable()?; @@ -92,6 +138,8 @@ pub(crate) async fn unset_lsm_write_spec(table: &NativeTable) -> Result<()> { } } + table.dataset.shard_writer().drain_and_close().await?; + let mut dataset = (*table.dataset.get().await?).clone(); dataset .drop_index(lance_index::mem_wal::MEM_WAL_INDEX_NAME) @@ -99,3 +147,937 @@ pub(crate) async fn unset_lsm_write_spec(table: &NativeTable) -> Result<()> { table.dataset.update(dataset); Ok(()) } + +// ============================================================================= +// close_lsm_writers +// ============================================================================= + +/// Drain and close every cached MemWAL shard writer for the table. +#[allow(clippy::redundant_pub_crate)] +pub(crate) async fn close_lsm_writers(table: &NativeTable) -> Result<()> { + table.dataset.shard_writer().drain_and_close().await +} + +// ============================================================================= +// ShardWriter cache +// ============================================================================= + +/// Per-table cache holding the single open MemWAL `ShardWriter`. +/// +/// Held by [`DatasetConsistencyWrapper`](crate::table::dataset::DatasetConsistencyWrapper) +/// so the writer lives where the dataset lives — cached for the session and +/// reused across `merge_insert` calls. A dataset writes to one shard at a +/// time; routing a `merge_insert` to a different shard requires closing the +/// current writer first via [`close_lsm_writers`]. `ShardWriter::put` takes +/// `&self`, so concurrent puts on the cached writer are safe; `close` consumes +/// the writer, so the entry wraps it in `RwLock>`. +#[derive(Default)] +#[allow(clippy::redundant_pub_crate)] +pub(crate) struct ShardWriterCache { + /// `Some((shard_id, entry))` once a writer has been opened for the session. + slot: RwLock)>>, +} + +impl std::fmt::Debug for ShardWriterCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ShardWriterCache").finish_non_exhaustive() + } +} + +struct ShardWriterEntry { + inner: RwLock>, +} + +impl ShardWriterEntry { + fn new(writer: ShardWriter) -> Self { + Self { + inner: RwLock::new(Some(writer)), + } + } + + async fn put(&self, batches: Vec) -> Result<()> { + let guard = self.inner.read().await; + let writer = guard.as_ref().ok_or_else(|| Error::Runtime { + message: "merge_insert: shard writer was closed before this write".to_string(), + })?; + writer.put(batches).await.map_err(|e| Error::Runtime { + message: format!("merge_insert: shard writer put failed: {}", e), + })?; + Ok(()) + } + + async fn close(&self) -> Result<()> { + let writer = { self.inner.write().await.take() }; + if let Some(writer) = writer { + writer.close().await.map_err(|e| Error::Runtime { + message: format!("merge_insert: shard writer close failed: {}", e), + })?; + } + Ok(()) + } +} + +impl ShardWriterCache { + /// Return the cached writer, opening one for `shard_id` with `config` if + /// the slot is empty. Errors if a writer is already open for a *different* + /// shard — the caller must close it first. + async fn writer_for_shard( + &self, + dataset: &Dataset, + shard_id: Uuid, + config: ShardWriterConfig, + ) -> Result> { + { + let guard = self.slot.read().await; + if let Some((cached, entry)) = guard.as_ref() { + check_shard_match(*cached, shard_id)?; + return Ok(entry.clone()); + } + } + let mut guard = self.slot.write().await; + // Re-check: another caller may have opened the writer meanwhile. + if let Some((cached, entry)) = guard.as_ref() { + check_shard_match(*cached, shard_id)?; + return Ok(entry.clone()); + } + let writer = dataset + .mem_wal_writer(shard_id, config) + .await + .map_err(|e| Error::Runtime { + message: format!( + "merge_insert: failed to open MemWAL shard writer for shard {}: {}", + shard_id, e + ), + })?; + let entry = Arc::new(ShardWriterEntry::new(writer)); + *guard = Some((shard_id, entry.clone())); + Ok(entry) + } + + /// Close the cached writer, if any, and clear the slot. + #[allow(clippy::redundant_pub_crate)] + pub(crate) async fn drain_and_close(&self) -> Result<()> { + let cached = { self.slot.write().await.take() }; + if let Some((_, entry)) = cached { + entry.close().await?; + } + Ok(()) + } +} + +/// Error if a cached writer is open for a shard other than the one needed. +fn check_shard_match(cached: Uuid, wanted: Uuid) -> Result<()> { + if cached == wanted { + return Ok(()); + } + Err(Error::InvalidInput { + message: format!( + "merge_insert: a shard writer is already open for shard {} but this input routes to shard {}; call close_lsm_writers before writing to a different shard", + cached, wanted + ), + }) +} + +// ============================================================================= +// merge_insert LSM dispatch +// ============================================================================= + +/// How the installed sharding spec routes rows to shards. +#[derive(Debug, Clone)] +enum LsmMode { + /// Hash-bucket the routing column into `num_buckets` shards. + Bucket { spec: ShardingSpec }, + /// Shard by the raw value of the routing column. + Identity { spec: ShardingSpec }, + /// Route every row to a single shard. + Unsharded, +} + +/// Resolved plan for routing a `merge_insert` through the MemWAL write path. +#[derive(Debug)] +#[allow(clippy::redundant_pub_crate)] +pub(crate) struct LsmPlan { + mode: LsmMode, + writer_config_defaults: HashMap, +} + +/// Outcome of [`lsm_dispatch_decision`]. +#[allow(clippy::redundant_pub_crate)] +pub(crate) enum LsmDispatch { + /// No LSM write spec applies; use the standard `merge_insert` path. + Standard, + /// Route the `merge_insert` through the MemWAL shard writer. + Lsm(LsmPlan), +} + +/// Decide whether a `merge_insert` should be routed through the MemWAL write +/// path, validating the builder against the installed spec. +#[allow(clippy::redundant_pub_crate)] +pub(crate) async fn lsm_dispatch_decision( + table: &NativeTable, + params: &MergeInsertBuilder, +) -> Result { + // `Some(false)` is an explicit opt-out: use the standard path. + if params.use_lsm_write == Some(false) { + return Ok(LsmDispatch::Standard); + } + + let dataset = table.dataset.get().await?; + let Some(details) = dataset.mem_wal_index_details().await? else { + // No LSM write spec installed. `Some(true)` explicitly asked for the + // LSM path, which is meaningless without a spec; `None` (the default) + // just falls back to the standard path. + if params.use_lsm_write == Some(true) { + return Err(Error::InvalidInput { + message: "merge_insert: use_lsm_write(true) requires an LSM write spec on the table; call set_lsm_write_spec first".to_string(), + }); + } + return Ok(LsmDispatch::Standard); + }; + + let pk_cols: Vec = dataset + .schema() + .unenforced_primary_key() + .iter() + .map(|f| f.name.clone()) + .collect(); + if pk_cols.is_empty() { + return Err(Error::Runtime { + message: "merge_insert: table has a MemWAL index but no unenforced primary key" + .to_string(), + }); + } + if !params.on.is_empty() && params.on != pk_cols { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: `on` columns {:?} must match the table's unenforced primary key {:?} when an LSM write spec is set; pass an empty `on` to default to the primary key", + params.on, pk_cols + ), + }); + } + + if !is_upsert_only(params) { + return Err(Error::InvalidInput { + message: "merge_insert: when an LSM write spec is set, only the upsert form (when_matched_update_all without a filter + when_not_matched_insert_all, no by-source delete) is supported; call use_lsm_write(false) to use the standard merge_insert path".to_string(), + }); + } + + let mode = resolve_lsm_mode(&details)?; + Ok(LsmDispatch::Lsm(LsmPlan { + mode, + writer_config_defaults: details.writer_config_defaults, + })) +} + +/// Returns true if the builder requests the upsert-only shape the LSM write +/// path can honor. +fn is_upsert_only(params: &MergeInsertBuilder) -> bool { + params.when_matched_update_all + && params.when_matched_update_all_filt.is_none() + && params.when_not_matched_insert_all + && !params.when_not_matched_by_source_delete + && params.when_not_matched_by_source_delete_filt.is_none() +} + +/// Read the sharding mode from the MemWAL index details. +fn resolve_lsm_mode(details: &MemWalIndexDetails) -> Result { + let spec = details + .sharding_specs + .first() + .cloned() + .ok_or_else(|| Error::Runtime { + message: "merge_insert: MemWAL index has no sharding spec".to_string(), + })?; + let field = spec.fields.first().ok_or_else(|| Error::Runtime { + message: "merge_insert: MemWAL index has an empty sharding spec".to_string(), + })?; + match field.transform.as_deref() { + Some(BUCKET_TRANSFORM) => { + field + .parameters + .get(NUM_BUCKETS_PARAM) + .and_then(|s| s.parse::().ok()) + .filter(|n| *n > 0) + .ok_or_else(|| Error::Runtime { + message: "merge_insert: MemWAL bucket spec has a missing or invalid num_buckets parameter".to_string(), + })?; + Ok(LsmMode::Bucket { spec }) + } + Some(IDENTITY_TRANSFORM) => Ok(LsmMode::Identity { spec }), + Some(UNSHARDED_TRANSFORM) => Ok(LsmMode::Unsharded), + other => Err(Error::Runtime { + message: format!( + "merge_insert: MemWAL index has an unsupported sharding transform {:?}", + other + ), + }), + } +} + +// ============================================================================= +// LSM merge_insert execution +// ============================================================================= + +/// Execute a `merge_insert` through the MemWAL shard writer cache. +/// +/// The entire input is collected, schema-aligned, and shard-validated before +/// anything is written, then issued as a single atomic `ShardWriter::put` — so +/// a validation failure (e.g. input spanning shards) never leaves a partial +/// write behind. When `validate_single_shard` is set, every row is checked to +/// route to one shard; when disabled, only the first row of the whole input is. +#[allow(clippy::redundant_pub_crate)] +pub(crate) async fn execute_lsm_merge_insert( + table: &NativeTable, + plan: LsmPlan, + validate_single_shard: bool, + new_data: Box, +) -> Result { + let dataset = table.dataset.get().await?; + let target_schema: SchemaRef = Arc::new(ArrowSchema::from(dataset.schema())); + + // Collect, align and shard-validate the whole input before writing + // anything. `ShardWriter::put` is atomic over the batch vector, so any + // failure raised here leaves the MemWAL untouched. + let mut batches: Vec = Vec::new(); + let mut total_rows: u64 = 0; + + for batch in new_data { + let batch = batch.map_err(|e| Error::Arrow { source: e })?; + if batch.num_rows() == 0 { + continue; + } + let batch = align_batch_schema(batch, &target_schema)?; + total_rows += batch.num_rows() as u64; + batches.push(batch); + } + + // Empty input (or only empty batches): nothing to write. + let Some(shard_id) = resolve_input_shard( + &plan.mode, + dataset.schema(), + &batches, + validate_single_shard, + )? + else { + return Ok(lsm_merge_result(0)); + }; + + let config = shard_writer_config_from_defaults(&plan.writer_config_defaults); + let writer = table + .dataset + .shard_writer() + .writer_for_shard(dataset.as_ref(), shard_id, config) + .await?; + writer.put(batches).await?; + + Ok(lsm_merge_result(total_rows)) +} + +/// Resolve the target shard for a collected input. +fn resolve_input_shard( + mode: &LsmMode, + schema: &LanceSchema, + batches: &[RecordBatch], + validate_single_shard: bool, +) -> Result> { + let mut shard_id: Option = None; + for batch in batches { + if batch.num_rows() == 0 { + continue; + } + if !validate_single_shard && shard_id.is_some() { + continue; + } + let batch_shard = resolve_batch_shard(mode, schema, batch, validate_single_shard)?; + match shard_id { + Some(seen) if seen != batch_shard => { + return Err(Error::InvalidInput { + message: "merge_insert: input batches route to multiple shards; each merge_insert call must target a single shard".to_string(), + }); + } + _ => shard_id = Some(batch_shard), + } + } + Ok(shard_id) +} + +/// Compute the target shard id for a non-empty batch. When +/// `validate_single_shard` is set, every row is checked to route to the same +/// shard; otherwise only the first row is inspected. +fn resolve_batch_shard( + mode: &LsmMode, + schema: &LanceSchema, + batch: &RecordBatch, + validate_single_shard: bool, +) -> Result { + let routing_batch = if validate_single_shard { + batch.clone() + } else { + batch.slice(0, 1) + }; + match mode { + LsmMode::Unsharded => Ok(unsharded_shard_id()), + LsmMode::Bucket { spec } => { + let values = evaluate_lsm_shard_values(&routing_batch, spec, schema)?; + let buckets = values + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::Runtime { + message: format!( + "merge_insert: MemWAL bucket evaluator returned {:?}; expected Int32", + values.data_type() + ), + })?; + let first = buckets.value(0); + if validate_single_shard { + for row in 1..routing_batch.num_rows() { + let bucket = buckets.value(row); + if bucket != first { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: input row 0 hashes to bucket {} but row {} hashes to bucket {}; each merge_insert call must target a single bucket (pre-shard the input, or set validate_single_shard(false) to route by the first row only)", + first, row, bucket + ), + }); + } + } + } + Ok(bucket_shard_id(u32::try_from(first).map_err(|_| { + Error::Runtime { + message: format!( + "merge_insert: MemWAL bucket evaluator returned negative bucket {}", + first + ), + } + })?)) + } + LsmMode::Identity { spec } => { + let values = evaluate_lsm_shard_values(&routing_batch, spec, schema)?; + let first = encode_scalar(values.as_ref(), 0)?; + if validate_single_shard { + for row in 1..routing_batch.num_rows() { + if encode_scalar(values.as_ref(), row)? != first { + return Err(Error::InvalidInput { + message: "merge_insert: input rows have differing values for identity-sharding column; each merge_insert call must target a single shard (pre-shard the input, or set validate_single_shard(false) to route by the first row only)".to_string(), + }); + } + } + } + Ok(identity_shard_id(&first)) + } + } +} + +fn evaluate_lsm_shard_values( + batch: &RecordBatch, + spec: &ShardingSpec, + schema: &LanceSchema, +) -> Result { + let values = evaluate_sharding_spec(batch, spec, schema)?; + if values.num_columns() != 1 { + return Err(Error::Runtime { + message: format!( + "merge_insert: MemWAL sharding spec evaluated to {} fields; expected exactly one", + values.num_columns() + ), + }); + } + Ok(values.column(0).clone()) +} + +/// Encode one cell of an identity-sharding column to comparable bytes. +fn encode_scalar(array: &dyn Array, row: usize) -> Result> { + if array.is_null(row) { + return Err(Error::InvalidInput { + message: "merge_insert: identity sharding does not support null routing values" + .to_string(), + }); + } + Ok(match array.data_type() { + DataType::Int8 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Int16 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Int32 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Int64 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt8 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt16 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt32 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::UInt64 => array + .as_primitive::() + .value(row) + .to_le_bytes() + .to_vec(), + DataType::Utf8 => array.as_string::().value(row).as_bytes().to_vec(), + DataType::LargeUtf8 => array.as_string::().value(row).as_bytes().to_vec(), + DataType::Boolean => vec![u8::from(array.as_boolean().value(row))], + other => { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: identity sharding does not support column dtype {:?}", + other + ), + }); + } + }) +} + +/// Deterministic shard id for a bucket index. +fn bucket_shard_id(bucket: u32) -> Uuid { + Uuid::new_v5(&SHARD_NAMESPACE, format!("bucket-{}", bucket).as_bytes()) +} + +/// Deterministic shard id for an identity value. +fn identity_shard_id(value: &[u8]) -> Uuid { + let mut name = b"identity-".to_vec(); + name.extend_from_slice(value); + Uuid::new_v5(&SHARD_NAMESPACE, &name) +} + +/// Deterministic shard id for the single unsharded shard. +fn unsharded_shard_id() -> Uuid { + Uuid::new_v5(&SHARD_NAMESPACE, b"unsharded") +} + +/// Build a [`ShardWriterConfig`] from the persisted `writer_config_defaults`. +/// +/// Unknown or unparseable keys are ignored; absent keys keep the +/// [`ShardWriterConfig`] default. The shard id is set by `mem_wal_writer`. +fn shard_writer_config_from_defaults(defaults: &HashMap) -> ShardWriterConfig { + let mut config = ShardWriterConfig::default().with_shard_spec_id(SHARDING_SPEC_ID); + let bool_of = |key: &str| defaults.get(key).and_then(|s| s.parse::().ok()); + let usize_of = |key: &str| defaults.get(key).and_then(|s| s.parse::().ok()); + let millis_of = |key: &str| { + defaults + .get(key) + .and_then(|s| s.parse::().ok()) + .map(Duration::from_millis) + }; + + if let Some(v) = bool_of("durable_write") { + config = config.with_durable_write(v); + } + if let Some(v) = bool_of("sync_indexed_write") { + config = config.with_sync_indexed_write(v); + } + if let Some(v) = usize_of("max_wal_buffer_size") { + config = config.with_max_wal_buffer_size(v); + } + if let Some(v) = usize_of("max_memtable_size") { + config = config.with_max_memtable_size(v); + } + if let Some(v) = usize_of("max_memtable_rows") { + config = config.with_max_memtable_rows(v); + } + if let Some(v) = usize_of("max_memtable_batches") { + config = config.with_max_memtable_batches(v); + } + if let Some(v) = usize_of("manifest_scan_batch_size") { + config = config.with_manifest_scan_batch_size(v); + } + if let Some(v) = usize_of("max_unflushed_memtable_bytes") { + config = config.with_max_unflushed_memtable_bytes(v); + } + if let Some(v) = millis_of("backpressure_log_interval_ms") { + config = config.with_backpressure_log_interval(v); + } + if let Some(v) = usize_of("async_index_buffer_rows") { + config = config.with_async_index_buffer_rows(v); + } + if let Some(v) = millis_of("async_index_interval_ms") { + config = config.with_async_index_interval(v); + } + if let Some(v) = bool_of("enable_memtable") { + config = config.with_enable_memtable(v); + } + if let Some(v) = millis_of("max_wal_flush_interval_ms") { + config = config.with_max_wal_flush_interval(v); + } + if let Some(v) = millis_of("stats_log_interval_ms") { + config = config.with_stats_log_interval(Some(v)); + } + config +} + +/// Re-attach the dataset's Arrow schema (including field metadata) to a +/// user-supplied input batch. The MemWAL `ShardWriter` checks batch schemas +/// against the dataset schema by exact equality, so input readers built +/// without the primary-key metadata must be rewrapped before being put. +/// +/// Columns are matched by name; column order in the input is irrelevant. +fn align_batch_schema(batch: RecordBatch, target: &SchemaRef) -> Result { + if batch.schema() == *target { + return Ok(batch); + } + let mut columns = Vec::with_capacity(target.fields().len()); + for field in target.fields() { + let column = batch + .column_by_name(field.name()) + .ok_or_else(|| Error::InvalidInput { + message: format!( + "merge_insert: input is missing column '{}' required by the table schema", + field.name() + ), + })?; + if column.data_type() != field.data_type() { + return Err(Error::InvalidInput { + message: format!( + "merge_insert: input column '{}' has dtype {:?}, expected {:?}", + field.name(), + column.data_type(), + field.data_type() + ), + }); + } + columns.push(column.clone()); + } + RecordBatch::try_new(target.clone(), columns).map_err(|e| Error::Arrow { source: e }) +} + +/// Build the [`MergeResult`] for an LSM-path `merge_insert`. +/// +/// The insert/update breakdown is not known until LSM compaction, so only the +/// total row count is reported. +fn lsm_merge_result(num_rows: u64) -> MergeResult { + MergeResult { + version: 0, + num_inserted_rows: 0, + num_updated_rows: 0, + num_deleted_rows: 0, + num_attempts: 0, + num_rows, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ArrayRef, BooleanArray, Int32Array, Int64Array, StringArray, UInt64Array}; + use arrow_schema::Field; + use lance_index::mem_wal::ShardingField; + + fn lance_schema(batch: &RecordBatch) -> LanceSchema { + LanceSchema::try_from(batch.schema().as_ref()).unwrap() + } + + fn single_field_spec(field: ShardingField) -> ShardingSpec { + ShardingSpec { + spec_id: SHARDING_SPEC_ID, + fields: vec![field], + } + } + + fn bucket_mode(source_id: i32, num_buckets: u32) -> LsmMode { + LsmMode::Bucket { + spec: single_field_spec(ShardingField { + field_id: "bucket".to_string(), + source_ids: vec![source_id], + transform: Some(BUCKET_TRANSFORM.to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: HashMap::from([( + NUM_BUCKETS_PARAM.to_string(), + num_buckets.to_string(), + )]), + }), + } + } + + fn identity_mode(source_id: i32) -> LsmMode { + LsmMode::Identity { + spec: single_field_spec(ShardingField { + field_id: "identity".to_string(), + source_ids: vec![source_id], + transform: Some(IDENTITY_TRANSFORM.to_string()), + expression: None, + result_type: "utf8".to_string(), + parameters: HashMap::new(), + }), + } + } + + fn bucket_values(batch: &RecordBatch, num_buckets: u32) -> Vec { + let LsmMode::Bucket { spec } = bucket_mode(0, num_buckets) else { + unreachable!(); + }; + let values = evaluate_lsm_shard_values(batch, &spec, &lance_schema(batch)).unwrap(); + values.as_primitive::().values().to_vec() + } + + #[test] + fn bucket_assignments_are_pinned() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])) as ArrayRef, + )]) + .unwrap(); + assert_eq!(bucket_values(&batch, 8), vec![1, 5, 0]); + } + + #[test] + fn bucket_int32_uses_lance_evaluator() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(3)])) as ArrayRef, + )]) + .unwrap(); + assert_eq!(bucket_values(&batch, 8), vec![2, 7, 0, 1]); + } + + #[test] + fn bucket_accepts_lance_supported_scalar_types() { + let bool_batch = RecordBatch::try_from_iter([( + "id", + Arc::new(BooleanArray::from(vec![true])) as ArrayRef, + )]) + .unwrap(); + assert!( + resolve_batch_shard( + &bucket_mode(0, 8), + &lance_schema(&bool_batch), + &bool_batch, + true + ) + .is_ok() + ); + + let u64_batch = RecordBatch::try_from_iter([( + "id", + Arc::new(UInt64Array::from(vec![1_u64])) as ArrayRef, + )]) + .unwrap(); + assert!( + resolve_batch_shard( + &bucket_mode(0, 8), + &lance_schema(&u64_batch), + &u64_batch, + true + ) + .is_ok() + ); + } + + #[test] + fn shard_ids_are_deterministic_and_distinct() { + assert_eq!(bucket_shard_id(3), bucket_shard_id(3)); + assert_ne!(bucket_shard_id(3), bucket_shard_id(4)); + assert_ne!(bucket_shard_id(0), unsharded_shard_id()); + assert_eq!( + identity_shard_id(b"tenant-a"), + identity_shard_id(b"tenant-a") + ); + assert_ne!( + identity_shard_id(b"tenant-a"), + identity_shard_id(b"tenant-b") + ); + } + + #[test] + fn encode_scalar_distinguishes_values() { + let ints = Int64Array::from(vec![1, 2]); + assert_ne!( + encode_scalar(&ints, 0).unwrap(), + encode_scalar(&ints, 1).unwrap() + ); + let strs = StringArray::from(vec!["x", "y"]); + assert_ne!( + encode_scalar(&strs, 0).unwrap(), + encode_scalar(&strs, 1).unwrap() + ); + } + + #[test] + fn writer_config_from_defaults_parses_known_keys() { + let defaults = HashMap::from([ + ("durable_write".to_string(), "false".to_string()), + ("max_memtable_rows".to_string(), "4096".to_string()), + ("async_index_interval_ms".to_string(), "250".to_string()), + ("unknown_key".to_string(), "ignored".to_string()), + ]); + let config = shard_writer_config_from_defaults(&defaults); + assert!(!config.durable_write); + assert_eq!(config.max_memtable_rows, 4096); + assert_eq!(config.async_index_interval, Duration::from_millis(250)); + assert_eq!(config.shard_spec_id, SHARDING_SPEC_ID); + } + + #[test] + fn align_batch_schema_reorders_columns() { + let target: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])); + let source = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![ + Field::new("v", DataType::Int64, false), + Field::new("id", DataType::Int64, false), + ])), + vec![ + Arc::new(Int64Array::from(vec![10, 20])), + Arc::new(Int64Array::from(vec![1, 2])), + ], + ) + .unwrap(); + let aligned = align_batch_schema(source, &target).unwrap(); + assert_eq!(aligned.schema(), target); + assert_eq!( + aligned.column(0).as_primitive::().values(), + &[1, 2] + ); + } + + #[test] + fn align_batch_schema_rejects_missing_column() { + let target: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])); + let source = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int64, + false, + )])), + vec![Arc::new(Int64Array::from(vec![1, 2]))], + ) + .unwrap(); + assert!(matches!( + align_batch_schema(source, &target), + Err(Error::InvalidInput { .. }) + )); + } + + fn utf8_batch(col: &str, values: Vec<&str>) -> RecordBatch { + RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![Field::new( + col, + DataType::Utf8, + true, + )])), + vec![Arc::new(StringArray::from(values))], + ) + .unwrap() + } + + #[test] + fn resolve_batch_shard_bucket_same_bucket() { + let mode = bucket_mode(0, 8); + let batch = utf8_batch("id", vec!["a", "a"]); + assert_eq!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, true).unwrap(), + bucket_shard_id(1) + ); + } + + #[test] + fn resolve_batch_shard_bucket_rejects_mixed() { + let mode = bucket_mode(0, 8); + let batch = utf8_batch("id", vec!["a", "b"]); + // validate_single_shard rejects a batch that spans buckets. + assert!(matches!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, true), + Err(Error::InvalidInput { .. }) + )); + // With validation off, only row 0 is inspected, so it is accepted. + assert_eq!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, false).unwrap(), + bucket_shard_id(1) + ); + } + + #[test] + fn resolve_batch_shard_bucket_routes_nulls_to_zero() { + let mode = bucket_mode(0, 8); + let batch = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int64, + true, + )])), + vec![Arc::new(Int64Array::from(vec![None, None]))], + ) + .unwrap(); + assert_eq!( + resolve_batch_shard(&mode, &lance_schema(&batch), &batch, true).unwrap(), + bucket_shard_id(0) + ); + } + + #[test] + fn resolve_batch_shard_rejects_missing_routing_column() { + let mode = bucket_mode(0, 8); + let schema = LanceSchema::try_from(&ArrowSchema::new(vec![Field::new( + "id", + DataType::Utf8, + true, + )])) + .unwrap(); + let batch = utf8_batch("other", vec!["a"]); + assert!(resolve_batch_shard(&mode, &schema, &batch, true).is_err()); + } + + #[test] + fn resolve_batch_shard_identity_groups_by_value() { + let mode = identity_mode(0); + let same = utf8_batch("region", vec!["us", "us"]); + let mixed = utf8_batch("region", vec!["us", "eu"]); + assert!(resolve_batch_shard(&mode, &lance_schema(&same), &same, true).is_ok()); + assert!(matches!( + resolve_batch_shard(&mode, &lance_schema(&mixed), &mixed, true), + Err(Error::InvalidInput { .. }) + )); + // With validation off, the mixed batch is accepted (row 0 only). + assert!(resolve_batch_shard(&mode, &lance_schema(&mixed), &mixed, false).is_ok()); + } + + #[test] + fn resolve_input_shard_validation_off_only_uses_first_input_row() { + let mode = bucket_mode(0, 8); + let first = utf8_batch("id", vec!["a"]); + let second = utf8_batch("id", vec!["b"]); + let schema = lance_schema(&first); + assert_eq!( + resolve_input_shard(&mode, &schema, &[first.clone(), second.clone()], false).unwrap(), + Some(bucket_shard_id(1)) + ); + assert!(matches!( + resolve_input_shard(&mode, &schema, &[first, second], true), + Err(Error::InvalidInput { .. }) + )); + } + + #[test] + fn resolve_batch_shard_unsharded_is_constant() { + let batch = utf8_batch("anything", vec!["a", "b", "c"]); + assert_eq!( + resolve_batch_shard(&LsmMode::Unsharded, &lance_schema(&batch), &batch, true).unwrap(), + unsharded_shard_id() + ); + } +}