mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-24 13:59:58 +00:00

Compare commits (42 commits): python-v0.… to python-v0.…
| Author | SHA1 | Date |
|---|---|---|
| | a27c5cf12b | |
| | f4dea72cc5 | |
| | f76c4a5ce1 | |
| | 164ce397c2 | |
| | 445a312667 | |
| | 92d845fa72 | |
| | 397813f6a4 | |
| | 50c30c5d34 | |
| | c9f248b058 | |
| | 0cb6da6b7e | |
| | aec8332eb5 | |
| | 46061070e6 | |
| | dae8334d0b | |
| | 8c81968b59 | |
| | 16cf2990f3 | |
| | 0a0f667bbd | |
| | 03753fd84b | |
| | 55cceaa309 | |
| | c3797eb834 | |
| | c0d0f38494 | |
| | 6a8ab78d0a | |
| | 27404c8623 | |
| | f181c7e77f | |
| | e70fd4fecc | |
| | ac0068b80e | |
| | ebac960571 | |
| | 59b57055e7 | |
| | 591c8de8fc | |
| | f835ff310f | |
| | cf8c2edaf4 | |
| | 61a714a459 | |
| | 5ddd84cec0 | |
| | 27ef0bb0a2 | |
| | 25402ba6ec | |
| | 37c359ed40 | |
| | 06cdf00987 | |
| | 144b7f5d54 | |
| | edc9b9adec | |
| | d11b2a6975 | |
| | 980aa70e2d | |
| | d83e5a0208 | |
| | 16a6b9ce8f | |
````diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.14.1-beta.3"
+current_version = "0.14.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
````
.github/workflows/make-release-commit.yml (vendored, 4 lines changed)

````diff
@@ -97,3 +97,7 @@ jobs:
         if: ${{ !inputs.dry_run && inputs.other }}
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
+      - uses: ./.github/workflows/update_package_lock_nodejs
+        if: ${{ !inputs.dry_run && inputs.other }}
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
````
.github/workflows/npm-publish.yml (vendored, 200 lines changed)

````diff
@@ -159,7 +159,7 @@ jobs:
       - name: Install common dependencies
        run: |
          apk add protobuf-dev curl clang mold grep npm bash
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
          echo "source $HOME/.cargo/env" >> saved_env
          echo "export CC=clang" >> saved_env
          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -167,7 +167,7 @@ jobs:
        if: ${{ matrix.config.arch == 'aarch64' }}
        run: |
          source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
          crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
          sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
          apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -262,7 +262,7 @@ jobs:
      - name: Install common dependencies
        run: |
          apk add protobuf-dev curl clang mold grep npm bash openssl-dev openssl-libs-static
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
          echo "source $HOME/.cargo/env" >> saved_env
          echo "export CC=clang" >> saved_env
          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -272,7 +272,7 @@ jobs:
        if: ${{ matrix.config.arch == 'aarch64' }}
        run: |
          source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
          crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
          sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
          apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -334,50 +334,51 @@ jobs:
          path: |
            node/dist/lancedb-vectordb-win32*.tgz
 
-  node-windows-arm64:
-    name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: node-native-windows-${{ matrix.config.arch }}
-          path: |
-            node/dist/lancedb-vectordb-win32*.tgz
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # node-windows-arm64:
+  #   name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: node-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           node/dist/lancedb-vectordb-win32*.tgz
 
   nodejs-windows:
     name: lancedb ${{ matrix.target }}
@@ -413,57 +414,58 @@ jobs:
          path: |
            nodejs/dist/*.node
 
-  nodejs-windows-arm64:
-    name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
-    # Only runs on tags that matches the make-release action
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-          printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
-          chmod u+x $HOME/.cargo/bin/cargo-xwin
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: nodejs-native-windows-${{ matrix.config.arch }}
-          path: |
-            nodejs/dist/*.node
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # nodejs-windows-arm64:
+  #   name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # Only runs on tags that matches the make-release action
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #         printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
+  #         chmod u+x $HOME/.cargo/bin/cargo-xwin
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: nodejs-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           nodejs/dist/*.node
 
   release:
     name: vectordb NPM Publish
-    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows, node-windows-arm64]
+    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows]
    runs-on: ubuntu-latest
    # Only runs on tags that matches the make-release action
    if: startsWith(github.ref, 'refs/tags/v')
@@ -503,7 +505,7 @@ jobs:
 
   release-nodejs:
     name: lancedb NPM Publish
-    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows, nodejs-windows-arm64]
+    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows]
    runs-on: ubuntu-latest
    # Only runs on tags that matches the make-release action
    if: startsWith(github.ref, 'refs/tags/v')
@@ -571,7 +573,7 @@ jobs:
        uses: actions/checkout@v4
        with:
          ref: main
-          persist-credentials: false
+          token: ${{ secrets.LANCEDB_RELEASE_TOKEN }}
          fetch-depth: 0
          lfs: true
      - uses: ./.github/workflows/update_package_lock
@@ -589,7 +591,7 @@ jobs:
        uses: actions/checkout@v4
        with:
          ref: main
-          persist-credentials: false
+          token: ${{ secrets.LANCEDB_RELEASE_TOKEN }}
          fetch-depth: 0
          lfs: true
      - uses: ./.github/workflows/update_package_lock_nodejs
````
.github/workflows/python.yml (vendored, 4 lines changed)

````diff
@@ -30,10 +30,10 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-          python-version: "3.11"
+          python-version: "3.12"
      - name: Install ruff
        run: |
-          pip install ruff==0.5.4
+          pip install ruff==0.8.4
      - name: Format check
        run: ruff format --check .
      - name: Lint
````
.github/workflows/rust.yml (vendored, 40 lines changed)

````diff
@@ -185,7 +185,7 @@ jobs:
          Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
 
          # Add MSVC runtime libraries to LIB
-          $env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
+          $env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
                     "C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
                     "C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
          Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
@@ -238,3 +238,41 @@ jobs:
        $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
        cargo build --target aarch64-pc-windows-msvc
        cargo test --target aarch64-pc-windows-msvc
+
+  msrv:
+    # Check the minimum supported Rust version
+    name: MSRV Check - Rust v${{ matrix.msrv }}
+    runs-on: ubuntu-24.04
+    strategy:
+      matrix:
+        msrv: ["1.78.0"] # This should match up with rust-version in Cargo.toml
+    env:
+      # Need up-to-date compilers for kernels
+      CC: clang-18
+      CXX: clang++-18
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: true
+      - name: Install dependencies
+        run: |
+          sudo apt update
+          sudo apt install -y protobuf-compiler libssl-dev
+      - name: Install ${{ matrix.msrv }}
+        uses: dtolnay/rust-toolchain@master
+        with:
+          toolchain: ${{ matrix.msrv }}
+      - name: Downgrade dependencies
+        # These packages have newer requirements for MSRV
+        run: |
+          cargo update -p aws-sdk-bedrockruntime --precise 1.64.0
+          cargo update -p aws-sdk-dynamodb --precise 1.55.0
+          cargo update -p aws-config --precise 1.5.10
+          cargo update -p aws-sdk-kms --precise 1.51.0
+          cargo update -p aws-sdk-s3 --precise 1.65.0
+          cargo update -p aws-sdk-sso --precise 1.50.0
+          cargo update -p aws-sdk-ssooidc --precise 1.51.0
+          cargo update -p aws-sdk-sts --precise 1.51.0
+          cargo update -p home --precise 0.5.9
+      - name: cargo +${{ matrix.msrv }} check
+        run: cargo check --workspace --tests --benches --all-features
````
.github/workflows/upload_wheel/action.yml (vendored, 4 lines changed)

````diff
@@ -22,7 +22,7 @@ runs:
      shell: bash
      id: choose_repo
      run: |
-        if [ ${{ github.ref }} == "*beta*" ]; then
+        if [[ ${{ github.ref }} == *beta* ]]; then
          echo "repo=fury" >> $GITHUB_OUTPUT
        else
          echo "repo=pypi" >> $GITHUB_OUTPUT
@@ -33,7 +33,7 @@ runs:
        FURY_TOKEN: ${{ inputs.fury_token }}
        PYPI_TOKEN: ${{ inputs.pypi_token }}
      run: |
-        if [ ${{ steps.choose_repo.outputs.repo }} == "fury" ]; then
+        if [[ ${{ steps.choose_repo.outputs.repo }} == fury ]]; then
          WHEEL=$(ls target/wheels/lancedb-*.whl 2> /dev/null | head -n 1)
          echo "Uploading $WHEEL to Fury"
          curl -f -F package=@$WHEEL https://$FURY_TOKEN@push.fury.io/lancedb/
````
Cargo.toml (20 lines changed)

````diff
@@ -18,19 +18,19 @@ repository = "https://github.com/lancedb/lancedb"
 description = "Serverless, low-latency vector database for AI applications"
 keywords = ["lancedb", "lance", "database", "vector", "search"]
 categories = ["database-implementations"]
-rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
+rust-version = "1.78.0"
 
 [workspace.dependencies]
-lance = { "version" = "=0.21.0", "features" = [
+lance = { "version" = "=0.21.1", "features" = [
   "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
-lance-io = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
-lance-index = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
-lance-linalg = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
-lance-table = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
-lance-testing = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
-lance-datafusion = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
-lance-encoding = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.3" }
+], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
 # Note that this one does not include pyarrow
 arrow = { version = "53.2", optional = false }
 arrow-array = "53.2"
````
````diff
@@ -62,6 +62,7 @@ plugins:
          # for cross references
          - https://arrow.apache.org/docs/objects.inv
          - https://pandas.pydata.org/docs/objects.inv
+          - https://lancedb.github.io/lance/objects.inv
  - mkdocs-jupyter
  - render_swagger:
      allow_arbitrary_locations: true
````
````diff
@@ -141,14 +141,6 @@ recommend switching to stable releases.
    --8<-- "python/python/tests/docs/test_basic.py:connect_async"
    ```
 
-    !!! note "Asynchronous Python API"
-
-        The asynchronous Python API is new and has some slight differences compared
-        to the synchronous API. Feel free to start using the asynchronous version.
-        Once all features have migrated we will start to move the synchronous API to
-        use the same syntax as the asynchronous API. To help with this migration we
-        have created a [migration guide](migration.md) detailing the differences.
-
 === "Typescript[^1]"
 
    === "@lancedb/lancedb"
````
````diff
@@ -50,7 +50,7 @@ Consider that we have a LanceDB table named `my_table`, whose string column `text` ...
    });
 
    await tbl
-      .search("puppy", queryType="fts")
+      .search("puppy", "fts")
      .select(["text"])
      .limit(10)
      .toArray();
````
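For reference, a minimal end-to-end sketch of the corrected call (the database path and table name are illustrative, and the table is assumed to already have an FTS index on its `text` column):

```typescript
import { connect } from "@lancedb/lancedb";

async function ftsExample() {
  const db = await connect("data/sample-lancedb");
  const tbl = await db.openTable("my_table");
  // The query type is a positional string argument ("fts"),
  // not a Python-style keyword argument as in the removed line above.
  const results = await tbl
    .search("puppy", "fts")
    .select(["text"])
    .limit(10)
    .toArray();
  console.log(results);
}
```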
````diff
@@ -1,81 +1,14 @@
 # Rust-backed Client Migration Guide
 
-In an effort to ensure all clients have the same set of capabilities we have begun migrating the
-python and node clients onto a common Rust base library. In python, this new client is part of
-the same lancedb package, exposed as an asynchronous client. Once the asynchronous client has
-reached full functionality we will begin migrating the synchronous library to be a thin wrapper
-around the asynchronous client.
+In an effort to ensure all clients have the same set of capabilities we have
+migrated the Python and Node clients onto a common Rust base library. In Python,
+both the synchronous and asynchronous clients are based on this implementation.
+In Node, the new client is available as `@lancedb/lancedb`, which replaces
+the existing `vectordb` package.
 
-This guide describes the differences between the two APIs and will hopefully assist users
+This guide describes the differences between the two Node APIs and will hopefully assist users
 that would like to migrate to the new API.
 
-## Python
-
-### Closeable Connections
-
-The Connection now has a `close` method. You can call this when
-you are done with the connection to eagerly free resources. Currently
-this is limited to freeing/closing the HTTP connection for remote
-connections. In the future we may add caching or other resources to
-native connections so this is probably a good practice even if you
-aren't using remote connections.
-
-In addition, the connection can be used as a context manager which may
-be a more convenient way to ensure the connection is closed.
-
-```python
-import lancedb
-
-async def my_async_fn():
-    with await lancedb.connect_async("my_uri") as db:
-        print(await db.table_names())
-```
-
-It is not mandatory to call the `close` method. If you do not call it
-then the connection will be closed when the object is garbage collected.
-
-### Closeable Table
-
-The Table now also has a `close` method, similar to the connection. This
-can be used to eagerly free the cache used by a Table object. Similar to
-the connection, it can be used as a context manager and it is not mandatory
-to call the `close` method.
-
-#### Changes to Table APIs
-
-- Previously `Table.schema` was a property. Now it is an async method.
-- The method `Table.__len__` was removed and `len(table)` will no longer
-  work. Use `Table.count_rows` instead.
-
-#### Creating Indices
-
-The `Table.create_index` method is now used for creating both vector indices
-and scalar indices. It currently requires a column name to be specified (the
-column to index). Vector index defaults are now smarter and scale better with
-the size of the data.
-
-To specify index configuration details you will need to specify which kind of
-index you are using.
-
-#### Querying
-
-The `Table.search` method has been renamed to `AsyncTable.vector_search` for
-clarity.
-
-### Features not yet supported
-
-The following features are not yet supported by the asynchronous API. However,
-we plan to support them soon.
-
-- You cannot specify an embedding function when creating or opening a table.
-  You must calculate embeddings yourself if using the asynchronous API
-- The merge insert operation is not supported in the asynchronous API
-- Cleanup / compact / optimize indices are not supported in the asynchronous API
-- add / alter columns is not supported in the asynchronous API
-- The asynchronous API does not yet support any full text search or reranking
-  search
-- Remote connections to LanceDb Cloud are not yet supported.
-- The method Table.head is not yet supported.
-
 ## TypeScript/JavaScript
 
 For JS/TS users, we offer a brand new SDK [@lancedb/lancedb](https://www.npmjs.com/package/@lancedb/lancedb)
````
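A minimal sketch of connecting through the new Rust-backed SDK (the database path is illustrative):

```typescript
import * as lancedb from "@lancedb/lancedb";

async function main() {
  // The new package replaces `vectordb`; the connection API is similar.
  const db = await lancedb.connect("data/sample-lancedb");
  console.log(await db.tableNames());
}
```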
````diff
@@ -47,6 +47,8 @@ is also an [asynchronous API client](#connections-asynchronous).
 
 ::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
 
+::: lancedb.embeddings.base.EmbeddingFunctionConfig
+
 ::: lancedb.embeddings.base.EmbeddingFunction
 
 ::: lancedb.embeddings.base.TextEmbeddingFunction
````
````diff
@@ -127,8 +129,16 @@ lists the indices that LanceDb supports.
 
 ::: lancedb.index.LabelList
 
+::: lancedb.index.FTS
+
 ::: lancedb.index.IvfPq
 
+::: lancedb.index.HnswPq
+
+::: lancedb.index.HnswSq
+
+::: lancedb.index.IvfFlat
+
 ## Querying (Asynchronous)
 
 Queries allow you to return data from your database. Basic queries can be
````
````diff
@@ -17,4 +17,8 @@ pip install lancedb
 ## Table
 
 ::: lancedb.remote.table.RemoteTable
+    options:
+      filters:
+        - "!cleanup_old_versions"
+        - "!compact_files"
+        - "!optimize"
````
````diff
@@ -13,11 +13,15 @@ A vector search finds the approximate or exact nearest neighbors to a given query vector.
 Distance metrics are a measure of the similarity between a pair of vectors.
 Currently, LanceDB supports the following metrics:
 
-| Metric   | Description                                                                  |
-| -------- | ---------------------------------------------------------------------------- |
-| `l2`     | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
-| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity)        |
-| `dot`    | [Dot Production](https://en.wikipedia.org/wiki/Dot_product)                 |
+| Metric    | Description                                                                  |
+| --------- | ---------------------------------------------------------------------------- |
+| `l2`      | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
+| `cosine`  | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity)        |
+| `dot`     | [Dot Production](https://en.wikipedia.org/wiki/Dot_product)                 |
+| `hamming` | [Hamming Distance](https://en.wikipedia.org/wiki/Hamming_distance)          |
+
+!!! note
+    The `hamming` metric is only available for binary vectors.
 
 ## Exhaustive search (kNN)
 
@@ -107,6 +111,31 @@ an ANN search means that using an index often involves a trade-off between recall...
 See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ`
 indexes work in LanceDB.
 
+## Binary vector
+
+LanceDB supports binary vectors as a data type, and has the ability to search binary vectors with hamming distance. The binary vectors are stored as uint8 arrays (every 8 bits are stored as a byte):
+
+!!! note
+    The dim of the binary vector must be a multiple of 8. A vector of dim 128 will be stored as a uint8 array of size 16.
+
+=== "Python"
+
+    === "sync API"
+
+        ```python
+        --8<-- "python/python/tests/docs/test_binary_vector.py:imports"
+
+        --8<-- "python/python/tests/docs/test_binary_vector.py:sync_binary_vector"
+        ```
+
+    === "async API"
+
+        ```python
+        --8<-- "python/python/tests/docs/test_binary_vector.py:imports"
+
+        --8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector"
+        ```
+
 ## Output search results
 
 LanceDB returns vector search results via different formats commonly used in python.
````
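As a rough illustration of the uint8 packing described above (the bit order within each byte is an assumption made for illustration only; LanceDB just requires the dim to be a multiple of 8):

```typescript
// Pack a 0/1 bit vector into bytes: a 128-dim binary vector
// becomes a Uint8Array of length 16.
function packBits(bits: number[]): Uint8Array {
  const packed = new Uint8Array(bits.length / 8);
  for (let i = 0; i < bits.length; i++) {
    // MSB-first packing, assumed here for illustration.
    if (bits[i]) packed[i >> 3] |= 1 << (7 - (i % 8));
  }
  return packed;
}

// Hamming distance between two packed vectors is the popcount of the XOR.
function hamming(a: Uint8Array, b: Uint8Array): number {
  let dist = 0;
  for (let i = 0; i < a.length; i++) {
    let x = a[i] ^ b[i];
    while (x) {
      dist += x & 1;
      x >>= 1;
    }
  }
  return dist;
}
```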
````diff
@@ -16,6 +16,7 @@ excluded_globs = [
    "../src/concepts/*.md",
    "../src/ann_indexes.md",
    "../src/basic.md",
+    "../src/search.md",
    "../src/hybrid_search/hybrid_search.md",
    "../src/reranking/*.md",
    "../src/guides/tuning_retrievers/*.md",
````
````diff
@@ -8,7 +8,7 @@
  <parent>
    <groupId>com.lancedb</groupId>
    <artifactId>lancedb-parent</artifactId>
-    <version>0.14.1-beta.3</version>
+    <version>0.14.1-final.0</version>
    <relativePath>../pom.xml</relativePath>
  </parent>
````

````diff
@@ -6,7 +6,7 @@
  <groupId>com.lancedb</groupId>
  <artifactId>lancedb-parent</artifactId>
-  <version>0.14.1-beta.3</version>
+  <version>0.14.1-final.0</version>
  <packaging>pom</packaging>
 
  <name>LanceDB Parent</name>
````
node/package-lock.json (generated, 111 lines changed)

````diff
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.14.1-beta.3",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"
@@ -52,14 +52,14 @@
        "uuid": "^9.0.0"
      },
      "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.3",
-        "@lancedb/vectordb-darwin-x64": "0.14.1-beta.3",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.3",
-        "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.3",
-        "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.3",
-        "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.3",
-        "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.3",
-        "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.3"
+        "@lancedb/vectordb-darwin-arm64": "0.14.1",
+        "@lancedb/vectordb-darwin-x64": "0.14.1",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+        "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+        "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+        "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+        "@lancedb/vectordb-win32-arm64-msvc": "0.14.1",
+        "@lancedb/vectordb-win32-x64-msvc": "0.14.1"
      },
      "peerDependencies": {
        "@apache-arrow/ts": "^14.0.2",
@@ -329,6 +329,97 @@
        "@jridgewell/sourcemap-codec": "^1.4.10"
      }
    },
+    "node_modules/@lancedb/vectordb-darwin-arm64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.14.1.tgz",
+      "integrity": "sha512-6t7XHR7dBjDmAS/kz5wbe7LPhKW+WkFA16ZPyh0lmuxfnss4VvN3LE6qQBHjzYzB9U6Nu/4ktQ50xZGEPTnc5A==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-darwin-x64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.14.1.tgz",
+      "integrity": "sha512-8q6Kd6XnNPKN8wqj75pHVQ4KFl6z9BaI6lWDiEaCNcO3bjPZkcLFNosJq4raxZ9iUi50Yl0qFJ6qR0XFVTwnnw==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.14.1.tgz",
+      "integrity": "sha512-4djEMmeNb+p6nW/C4xb8wdMwnIbWfO8fYAwiplOxzxeOpPaUC9rhwUUDCbrJDCpMa8RP5ED4/jC6yT8epaDMDw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.14.1.tgz",
+      "integrity": "sha512-c33hSsp16pnC58plzx1OXuifp9Rachx/MshE/L/OReoutt74fFdrRJwUjE4UCAysyY5QdvTrNm9OhDjopQK2Bw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.14.1.tgz",
+      "integrity": "sha512-psu6cH9iLiSbUEZD1EWbOA4THGYSwJvS2XICO9yN7A6D41AP/ynYMRZNKWo1fpdi2Fjb0xNQwiNhQyqwbi5gzA==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.14.1.tgz",
+      "integrity": "sha512-Rg4VWW80HaTFmR7EvNSu+nfRQQM8beO/otBn/Nus5mj5zFw/7cacGRmiEYhDnk5iAn8nauV+Jsi9j2U+C2hp5w==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.14.1.tgz",
+      "integrity": "sha512-XbifasmMbQIt3V9P0AtQND6M3XFiIAc1ZIgmjzBjOmxwqw4sQUwHMyJGIGOzKFZTK3fPJIGRHId7jAzXuBgfQg==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "win32"
+      ]
+    },
    "node_modules/@neon-rs/cli": {
      "version": "0.0.160",
      "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
````
node/package.json

````diff
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "description": " Serverless, low-latency vector database for AI applications",
   "private": false,
   "main": "dist/index.js",
@@ -92,13 +92,13 @@
    }
  },
  "optionalDependencies": {
-    "@lancedb/vectordb-darwin-x64": "0.14.1-beta.3",
-    "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.3",
-    "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.3",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.3",
-    "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.3",
-    "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.3",
-    "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.3",
-    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.3"
+    "@lancedb/vectordb-darwin-x64": "0.14.1",
+    "@lancedb/vectordb-darwin-arm64": "0.14.1",
+    "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+    "@lancedb/vectordb-win32-x64-msvc": "0.14.1",
+    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1"
  }
 }
````
````diff
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.14.1-beta.3"
+version = "0.14.1"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
@@ -12,7 +12,10 @@ categories.workspace = true
 crate-type = ["cdylib"]
 
 [dependencies]
+async-trait.workspace = true
+arrow-ipc.workspace = true
 arrow-array.workspace = true
 arrow-schema.workspace = true
 env_logger.workspace = true
 futures.workspace = true
 lancedb = { path = "../rust/lancedb", features = ["remote"] }
````
````diff
@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";
 
 import {
   convertToTable,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   fromTableToBuffer,
   makeArrowTable,
   makeEmptyTable,
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
      });
    });
  });
+
+  describe("converting record batches to buffers", function () {
+    it("can convert to buffered record batch and back again", async function () {
+      const records = [
+        { text: "dog", vector: [0.1, 0.2] },
+        { text: "cat", vector: [0.3, 0.4] },
+      ];
+      const table = await convertToTable(records);
+      const batch = table.batches[0];
+
+      const buffer = await fromRecordBatchToBuffer(batch);
+      const result = await fromBufferToRecordBatch(buffer);
+
+      expect(JSON.stringify(batch.toArray())).toEqual(
+        JSON.stringify(result?.toArray()),
+      );
+    });
+
+    it("converting from buffer returns null if buffer has no record batches", async function () {
+      const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
+      expect(result).toEqual(null);
+    });
+  });
 },
);
````
nodejs/__test__/rerankers.test.ts (new file, 79 lines)

````diff
@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import * as tmp from "tmp";
+import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
+import { RRFReranker } from "../lancedb/rerankers";
+
+describe("rerankers", function () {
+  let tmpDir: tmp.DirResult;
+  let conn: Connection;
+  let table: Table;
+
+  beforeEach(async () => {
+    tmpDir = tmp.dirSync({ unsafeCleanup: true });
+    conn = await connect(tmpDir.name);
+    table = await conn.createTable("mytable", [
+      { vector: [0.1, 0.1], text: "dog" },
+      { vector: [0.2, 0.2], text: "cat" },
+    ]);
+    await table.createIndex("text", {
+      config: Index.fts(),
+      replace: true,
+    });
+  });
+
+  it("will query with the custom reranker", async function () {
+    const expectedResult = [
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ];
+    class MyCustomReranker {
+      async rerankHybrid(
+        _query: string,
+        _vecResults: RecordBatch,
+        _ftsResults: RecordBatch,
+      ): Promise<RecordBatch> {
+        // no reranker logic, just return some static data
+        const table = makeArrowTable(expectedResult);
+        return table.batches[0];
+      }
+    }
+
+    let result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(new MyCustomReranker())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
+    expect(result).toEqual([
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ]);
+  });
+
+  it("will query with RRFReranker", async function () {
+    // smoke test to see if the Rust wrapping Typescript is wired up correctly
+    const result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(await RRFReranker.create())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    expect(result).toHaveLength(2);
+  });
+});
````
````diff
@@ -27,7 +27,9 @@ import {
   List,
   Null,
   RecordBatch,
+  RecordBatchFileReader,
+  RecordBatchFileWriter,
   RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
  }
 }
 
+/**
+ * Read a single record batch from a buffer.
+ *
+ * Returns null if the buffer does not contain a record batch
+ */
+export async function fromBufferToRecordBatch(
+  data: Buffer,
+): Promise<RecordBatch | null> {
+  const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
+    .value;
+  const recordBatch = iter?.next().value;
+  return recordBatch || null;
+}
+
+/**
+ * Create a buffer containing a single record batch
+ */
+export async function fromRecordBatchToBuffer(
+  batch: RecordBatch,
+): Promise<Buffer> {
+  const writer = new RecordBatchFileWriter().writeAll([batch]);
+  return Buffer.from(await writer.toUint8Array());
+}
+
 /**
  * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
  *
````
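These helpers round-trip a batch through the Arrow IPC file format; the test added earlier in this change exercises exactly this path. A condensed usage sketch:

```typescript
import {
  convertToTable,
  fromBufferToRecordBatch,
  fromRecordBatchToBuffer,
} from "./arrow";

async function roundTrip() {
  const table = await convertToTable([{ text: "dog", vector: [0.1, 0.2] }]);
  const batch = table.batches[0];
  const buffer = await fromRecordBatchToBuffer(batch);
  // Returns null when the buffer holds no record batch (e.g. bad data).
  const restored = await fromBufferToRecordBatch(buffer);
  return restored;
}
```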
````diff
@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
 export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";
 
 export * as embedding from "./embedding";
+export * as rerankers from "./rerankers";
 
 /**
  * Connect to a LanceDB instance at the given URI.
````
````diff
@@ -16,6 +16,8 @@ import {
   Table as ArrowTable,
   type IntoVector,
   RecordBatch,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   tableFromIPC,
 } from "./arrow";
 import { type IvfPqOptions } from "./indices";
@@ -25,6 +27,7 @@ import {
   Table as NativeTable,
   VectorQuery as NativeVectorQuery,
 } from "./native";
+import { Reranker } from "./rerankers";
 
 export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
   private promisedInner?: Promise<NativeBatchIterator>;
   private inner?: NativeBatchIterator;
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
    return this;
  }
+
+  rerank(reranker: Reranker): VectorQuery {
+    super.doCall((inner) =>
+      inner.rerank({
+        rerankHybrid: async (_, args) => {
+          const vecResults = await fromBufferToRecordBatch(args.vecResults);
+          const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
+          const result = await reranker.rerankHybrid(
+            args.query,
+            vecResults as RecordBatch,
+            ftsResults as RecordBatch,
+          );
+
+          const buffer = fromRecordBatchToBuffer(result);
+          return buffer;
+        },
+      }),
+    );
+
+    return this;
+  }
 }
 
 /** A builder for LanceDB queries. */
````
nodejs/lancedb/rerankers/index.ts (new file, 17 lines)

````diff
@@ -0,0 +1,17 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+
+export * from "./rrf";
+
+// Interface for a reranker. A reranker is used to rerank the results from a
+// vector and FTS search. This is useful for combining the results from both
+// search methods.
+export interface Reranker {
+  rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch>;
+}
````
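Any object with a matching `rerankHybrid` method satisfies this interface; the test earlier in this change passes one straight into `VectorQuery.rerank`. A minimal sketch:

```typescript
import { RecordBatch } from "apache-arrow";
import { Reranker } from "./rerankers";

// A trivial reranker that returns the vector-search results unchanged;
// a real implementation would merge and re-score both result sets.
class PassthroughReranker implements Reranker {
  async rerankHybrid(
    _query: string,
    vecResults: RecordBatch,
    _ftsResults: RecordBatch,
  ): Promise<RecordBatch> {
    return vecResults;
  }
}
```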
nodejs/lancedb/rerankers/rrf.ts (new file, 40 lines)

````diff
@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
+import { RrfReranker as NativeRRFReranker } from "../native";
+
+/**
+ * Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
+ *
+ * Internally this uses the Rust implementation
+ */
+export class RRFReranker {
+  private inner: NativeRRFReranker;
+
+  constructor(inner: NativeRRFReranker) {
+    this.inner = inner;
+  }
+
+  public static async create(k: number = 60) {
+    return new RRFReranker(
+      await NativeRRFReranker.tryNew(new Float32Array([k])),
+    );
+  }
+
+  async rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch> {
+    const buffer = await this.inner.rerankHybrid(
+      query,
+      await fromRecordBatchToBuffer(vecResults),
+      await fromRecordBatchToBuffer(ftsResults),
+    );
+    const recordBatch = await fromBufferToRecordBatch(buffer);
+
+    return recordBatch as RecordBatch;
+  }
+}
````
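A usage sketch combining the Rust-backed RRF reranker with a hybrid query (database path, table, and data are illustrative; `create` defaults `k` to 60):

```typescript
import { connect, rerankers } from "@lancedb/lancedb";

async function hybridWithRrf() {
  const db = await connect("data/sample-lancedb");
  const tbl = await db.openTable("my_table");
  return await tbl
    .query()
    .nearestTo([0.1, 0.1])
    .fullTextSearch("dog")
    .rerank(await rerankers.RRFReranker.create(60))
    .select(["text"])
    .limit(5)
    .toArray();
}
```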
````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",
````

````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",
````

````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",
````

````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",
````

````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",
````

````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",
````

````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": [
     "win32"
   ],
````

````diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
````
nodejs/package-lock.json (generated, 4 lines changed)

````diff
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.14.0",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.14.0",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"
````
````diff
@@ -11,7 +11,7 @@
    "ann"
  ],
  "private": false,
-  "version": "0.14.1-beta.3",
+  "version": "0.14.1",
  "main": "dist/index.js",
  "exports": {
    ".": "./dist/index.js",
````
````diff
@@ -24,6 +24,7 @@ mod iterator;
 pub mod merge;
 mod query;
 pub mod remote;
+mod rerankers;
 mod table;
 mod util;
````
````diff
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::sync::Arc;
+
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::ExecutableQuery;
 use lancedb::query::Query as LanceDbQuery;
@@ -25,6 +27,8 @@ use napi_derive::napi;
 use crate::error::convert_error;
 use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
+use crate::rerankers::Reranker;
+use crate::rerankers::RerankerCallbacks;
 use crate::util::parse_distance_type;
 
 #[napi]
@@ -218,6 +222,14 @@ impl VectorQuery {
        self.inner = self.inner.clone().with_row_id();
    }
 
+    #[napi]
+    pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
+        self.inner = self
+            .inner
+            .clone()
+            .rerank(Arc::new(Reranker::new(callbacks)));
+    }
+
    #[napi(catch_unwind)]
    pub async fn execute(
        &self,
````
nodejs/src/rerankers.rs (new file, 147 lines)

````diff
@@ -0,0 +1,147 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow_array::RecordBatch;
+use async_trait::async_trait;
+use napi::{
+    bindgen_prelude::*,
+    threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
+};
+use napi_derive::napi;
+
+use lancedb::ipc::batches_to_ipc_file;
+use lancedb::rerankers::Reranker as LanceDBReranker;
+use lancedb::{error::Error, ipc::ipc_file_to_batches};
+
+use crate::error::NapiErrorExt;
+
+/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
+/// This contains references to the callbacks that can be used to invoke the
+/// reranking methods on the NodeJS implementation and handles serializing the
+/// record batches to Arrow IPC buffers.
+#[napi]
+pub struct Reranker {
+    /// callback to the Javascript which will call the rerankHybrid method of
+    /// some Reranker implementation
+    rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
+}
+
+#[napi]
+impl Reranker {
+    #[napi]
+    pub fn new(callbacks: RerankerCallbacks) -> Self {
+        let rerank_hybrid = callbacks
+            .rerank_hybrid
+            .create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
+            .unwrap();
+
+        Self { rerank_hybrid }
+    }
+}
+
+#[async_trait]
+impl lancedb::rerankers::Reranker for Reranker {
+    async fn rerank_hybrid(
+        &self,
+        query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> lancedb::error::Result<RecordBatch> {
+        let callback_args = RerankHybridCallbackArgs {
+            query: query.to_string(),
+            vec_results: batches_to_ipc_file(&[vector_results])?,
+            fts_results: batches_to_ipc_file(&[fts_results])?,
+        };
+        let promised_buffer: Promise<Buffer> = self
+            .rerank_hybrid
+            .call_async(Ok(callback_args))
+            .await
+            .map_err(|e| Error::Runtime {
+                message: format!("napi error status={}, reason={}", e.status, e.reason),
+            })?;
+        let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
+            message: format!("napi error status={}, reason={}", e.status, e.reason),
+        })?;
+        let mut reader = ipc_file_to_batches(buffer.to_vec())?;
+        let result = reader.next().ok_or(Error::Runtime {
+            message: "reranker result deserialization failed".to_string(),
+        })??;
+
+        return Ok(result);
+    }
+}
+
+impl std::fmt::Debug for Reranker {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("NodeJSRerankerWrapper")
+    }
+}
+
+#[napi(object)]
+pub struct RerankerCallbacks {
+    pub rerank_hybrid: JsFunction,
+}
+
+#[napi(object)]
+pub struct RerankHybridCallbackArgs {
+    pub query: String,
+    pub vec_results: Vec<u8>,
+    pub fts_results: Vec<u8>,
+}
+
+fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
+    let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
+    reader
+        .next()
+        .ok_or(Error::InvalidInput {
+            message: "expected buffer containing record batch".to_string(),
+        })
+        .default_error()?
+        .map_err(Error::from)
+        .default_error()
+}
+
+/// Wrapper around rust RRFReranker
+#[napi]
+pub struct RRFReranker {
+    inner: lancedb::rerankers::rrf::RRFReranker,
+}
+
+#[napi]
+impl RRFReranker {
+    #[napi]
+    pub async fn try_new(k: &[f32]) -> Result<Self> {
+        let k = k
+            .first()
+            .copied()
+            .ok_or(Error::InvalidInput {
+                message: "must supply RRF Reranker constructor arg 'k'".to_string(),
+            })
+            .default_error()?;
+
+        Ok(Self {
+            inner: lancedb::rerankers::rrf::RRFReranker::new(k),
+        })
+    }
+
+    #[napi]
+    pub async fn rerank_hybrid(
+        &self,
+        query: String,
+        vec_results: Buffer,
+        fts_results: Buffer,
+    ) -> Result<Buffer> {
+        let vec_results = buffer_to_record_batch(vec_results)?;
+        let fts_results = buffer_to_record_batch(fts_results)?;
+
+        let result = self
+            .inner
+            .rerank_hybrid(&query, vec_results, fts_results)
+            .await
+            .unwrap();
+
+        let result_buff = batches_to_ipc_file(&[result]).default_error()?;
+
+        Ok(Buffer::from(result_buff.as_ref()))
+    }
+}
````
````diff
@@ -5,8 +5,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
        "l2" => Ok(DistanceType::L2),
        "cosine" => Ok(DistanceType::Cosine),
        "dot" => Ok(DistanceType::Dot),
+        "hamming" => Ok(DistanceType::Hamming),
        _ => Err(napi::Error::from_reason(format!(
-            "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
+            "Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
            distance_type.as_ref()
        ))),
    }
````
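With this change, "hamming" becomes a valid distance-type string wherever the Node bindings parse one. A hedged sketch (it is an assumption here that the TS query builder exposes the distance type as a string option matching this parser; the table name is hypothetical):

```typescript
import { connect } from "@lancedb/lancedb";

async function hammingQuery(queryVec: number[]) {
  const db = await connect("data/sample-lancedb");
  const tbl = await db.openTable("binary_table"); // hypothetical table
  return await tbl
    .query()
    .nearestTo(queryVec)
    .distanceType("hamming") // now parsed alongside l2/cosine/dot
    .limit(10)
    .toArray();
}
```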
````diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.17.1-beta.4"
+current_version = "0.17.2-beta.2"
 parse = """(?x)
    (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.
````
````diff
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.17.1-beta.4"
+version = "0.17.2-beta.2"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
````
````diff
@@ -1,9 +1,10 @@
 [project]
 name = "lancedb"
+# version in Cargo.toml
 dynamic = ["version"]
 dependencies = [
    "deprecation",
-    "pylance==0.21.0b3",
+    "pylance==0.21.1b1",
    "tqdm>=4.27.0",
    "pydantic>=1.10",
    "packaging",
@@ -52,8 +53,9 @@ tests = [
    "pytz",
    "polars>=0.19, <=1.3.0",
    "tantivy",
+    "pyarrow-stubs"
 ]
-dev = ["ruff", "pre-commit"]
+dev = ["ruff", "pre-commit", "pyright"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -93,3 +95,7 @@ markers = [
    "asyncio",
    "s3_test",
 ]
+
+[tool.pyright]
+include = ["python/lancedb/table.py"]
+pythonVersion = "3.12"
````
@@ -70,7 +70,7 @@ def connect(
        default configuration is used.
    storage_options: dict, optional
        Additional options for the storage backend. See available options at
        https://lancedb.github.io/lancedb/guides/storage/
        <https://lancedb.github.io/lancedb/guides/storage/>

    Examples
    --------
@@ -82,11 +82,13 @@ def connect(

    For object storage, use a URI prefix:

    >>> db = lancedb.connect("s3://my-bucket/lancedb")
    >>> db = lancedb.connect("s3://my-bucket/lancedb",
    ...                      storage_options={"aws_access_key_id": "***"})

    Connect to LanceDB cloud:

    >>> db = lancedb.connect("db://my_database", api_key="ldb_...")
    >>> db = lancedb.connect("db://my_database", api_key="ldb_...",
    ...                      client_config={"retry_config": {"retries": 5}})

    Returns
    -------
@@ -164,7 +166,7 @@ async def connect_async(
        default configuration is used.
    storage_options: dict, optional
        Additional options for the storage backend. See available options at
        https://lancedb.github.io/lancedb/guides/storage/
        <https://lancedb.github.io/lancedb/guides/storage/>

    Examples
    --------

@@ -1,20 +1,11 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Any, Union, Literal

import pyarrow as pa

class Index:
    @staticmethod
    def ivf_pq(
        distance_type: Optional[str],
        num_partitions: Optional[int],
        num_sub_vectors: Optional[int],
        max_iterations: Optional[int],
        sample_rate: Optional[int],
    ) -> Index: ...
    @staticmethod
    def btree() -> Index: ...
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS

class Connection(object):
    uri: str
    async def table_names(
        self, start_after: Optional[str], limit: Optional[int]
    ) -> list[str]: ...
@@ -42,18 +33,35 @@ class Connection(object):
class Table:
    def name(self) -> str: ...
    def __repr__(self) -> str: ...
    def is_open(self) -> bool: ...
    def close(self) -> None: ...
    async def schema(self) -> pa.Schema: ...
    async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
    async def add(
        self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
    ) -> None: ...
    async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
    async def count_rows(self, filter: Optional[str]) -> int: ...
    async def create_index(
        self, column: str, config: Optional[Index], replace: Optional[bool]
        self,
        column: str,
        index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
        replace: Optional[bool],
    ): ...
    async def list_versions(self) -> List[Dict[str, Any]]: ...
    async def version(self) -> int: ...
    async def checkout(self, version): ...
    async def checkout(self, version: int): ...
    async def checkout_latest(self): ...
    async def restore(self): ...
    async def list_indices(self) -> List[IndexConfig]: ...
    async def list_indices(self) -> list[IndexConfig]: ...
    async def delete(self, filter: str): ...
    async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
    async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
    async def optimize(
        self,
        *,
        cleanup_since_ms: Optional[int] = None,
        delete_unverified: Optional[bool] = None,
    ) -> OptimizeStats: ...
    def query(self) -> Query: ...
    def vector_search(self) -> VectorQuery: ...

@@ -23,3 +23,6 @@ class BackgroundEventLoop:

    def run(self, future):
        return asyncio.run_coroutine_threadsafe(future, self.loop).result()


LOOP = BackgroundEventLoop()

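The hunk above adds a module-level LOOP singleton so the synchronous API can share one background loop. A self-contained sketch of the pattern, assuming the loop runs on a daemon thread (the thread setup does not appear in this hunk):

import asyncio
import threading

class BackgroundEventLoop:
    def __init__(self):
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self.thread.start()

    def run(self, future):
        # Submit a coroutine to the background loop and block for its result.
        return asyncio.run_coroutine_threadsafe(future, self.loop).result()

LOOP = BackgroundEventLoop()

async def answer() -> int:
    return 42

assert LOOP.run(answer()) == 42  # sync caller, async implementation
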
@@ -17,12 +17,13 @@ from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union

from overrides import EnforceOverrides, override
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from overrides import EnforceOverrides, override  # type: ignore

from lancedb.common import data_to_reader, sanitize_uri, validate_schema
from lancedb.background_loop import BackgroundEventLoop
from lancedb.background_loop import LOOP

from ._lancedb import connect as lancedb_connect
from ._lancedb import connect as lancedb_connect  # type: ignore
from .table import (
    AsyncTable,
    LanceTable,
@@ -43,8 +44,6 @@ if TYPE_CHECKING:
    from .common import DATA, URI
    from .embeddings import EmbeddingFunctionConfig

LOOP = BackgroundEventLoop()


class DBConnection(EnforceOverrides):
    """An active LanceDB connection interface."""
@@ -82,6 +81,10 @@ class DBConnection(EnforceOverrides):
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
        *,
        storage_options: Optional[Dict[str, str]] = None,
        data_storage_version: Optional[str] = None,
        enable_v2_manifest_paths: Optional[bool] = None,
    ) -> Table:
        """Create a [Table][lancedb.table.Table] in the database.

@@ -119,6 +122,24 @@ class DBConnection(EnforceOverrides):
            One of "error", "drop", "fill".
        fill_value: float
            The value to use when filling vectors. Only used if on_bad_vectors="fill".
        storage_options: dict, optional
            Additional options for the storage backend. Options already set on the
            connection will be inherited by the table, but can be overridden here.
            See available options at
            <https://lancedb.github.io/lancedb/guides/storage/>
        data_storage_version: optional, str, default "stable"
            The version of the data storage format to use. Newer versions are more
            efficient but require newer versions of lance to read. The default is
            "stable" which will use the legacy v2 version. See the user guide
            for more details.
        enable_v2_manifest_paths: bool, optional, default False
            Use the new V2 manifest paths. These paths provide more efficient
            opening of datasets with many versions on object stores. WARNING:
            turning this on will make the dataset unreadable for older versions
            of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
            use the
            [Table.migrate_manifest_paths_v2][lancedb.table.Table.migrate_v2_manifest_paths]
            method.

        Returns
        -------
@@ -140,7 +161,7 @@ class DBConnection(EnforceOverrides):
        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
        >>> db.create_table("my_table", data)
        LanceTable(connection=..., name="my_table")
        LanceTable(name='my_table', version=1, ...)
        >>> db["my_table"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
@@ -161,7 +182,7 @@ class DBConnection(EnforceOverrides):
        ...     "long": [-122.7, -74.1]
        ... })
        >>> db.create_table("table2", data)
        LanceTable(connection=..., name="table2")
        LanceTable(name='table2', version=1, ...)
        >>> db["table2"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
@@ -184,7 +205,7 @@ class DBConnection(EnforceOverrides):
        ...     pa.field("long", pa.float32())
        ... ])
        >>> db.create_table("table3", data, schema = custom_schema)
        LanceTable(connection=..., name="table3")
        LanceTable(name='table3', version=1, ...)
        >>> db["table3"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
@@ -218,7 +239,7 @@ class DBConnection(EnforceOverrides):
        ...     pa.field("price", pa.float32()),
        ... ])
        >>> db.create_table("table4", make_batches(), schema=schema)
        LanceTable(connection=..., name="table4")
        LanceTable(name='table4', version=1, ...)

        """
        raise NotImplementedError
@@ -226,7 +247,13 @@ class DBConnection(EnforceOverrides):
    def __getitem__(self, name: str) -> LanceTable:
        return self.open_table(name)

    def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
    def open_table(
        self,
        name: str,
        *,
        storage_options: Optional[Dict[str, str]] = None,
        index_cache_size: Optional[int] = None,
    ) -> Table:
        """Open a Lance Table in the database.

        Parameters
@@ -243,6 +270,11 @@ class DBConnection(EnforceOverrides):
            This cache applies to the entire opened table, across all indices.
            Setting this value higher will increase performance on larger datasets
            at the expense of more RAM
        storage_options: dict, optional
            Additional options for the storage backend. Options already set on the
            connection will be inherited by the table, but can be overridden here.
            See available options at
            <https://lancedb.github.io/lancedb/guides/storage/>

        Returns
        -------
@@ -309,15 +341,15 @@ class LanceDBConnection(DBConnection):
    >>> db = lancedb.connect("./.lancedb")
    >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
    ...                                   {"vector": [0.5, 1.3], "b": 4}])
    LanceTable(connection=..., name="my_table")
    LanceTable(name='my_table', version=1, ...)
    >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
    LanceTable(connection=..., name="another_table")
    LanceTable(name='another_table', version=1, ...)
    >>> sorted(db.table_names())
    ['another_table', 'my_table']
    >>> len(db)
    2
    >>> db["my_table"]
    LanceTable(connection=..., name="my_table")
    LanceTable(name='my_table', version=1, ...)
    >>> "my_table" in db
    True
    >>> db.drop_table("my_table")
@@ -363,7 +395,7 @@ class LanceDBConnection(DBConnection):
        self._conn = AsyncConnection(LOOP.run(do_connect()))

    def __repr__(self) -> str:
        val = f"{self.__class__.__name__}({self._uri}"
        val = f"{self.__class__.__name__}(uri={self._uri!r}"
        if self.read_consistency_interval is not None:
            val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
        val += ")"
@@ -403,6 +435,10 @@ class LanceDBConnection(DBConnection):
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
        *,
        storage_options: Optional[Dict[str, str]] = None,
        data_storage_version: Optional[str] = None,
        enable_v2_manifest_paths: Optional[bool] = None,
    ) -> LanceTable:
        """Create a table in the database.

@@ -424,12 +460,19 @@ class LanceDBConnection(DBConnection):
            on_bad_vectors=on_bad_vectors,
            fill_value=fill_value,
            embedding_functions=embedding_functions,
            storage_options=storage_options,
            data_storage_version=data_storage_version,
            enable_v2_manifest_paths=enable_v2_manifest_paths,
        )
        return tbl

    @override
    def open_table(
        self, name: str, *, index_cache_size: Optional[int] = None
        self,
        name: str,
        *,
        storage_options: Optional[Dict[str, str]] = None,
        index_cache_size: Optional[int] = None,
    ) -> LanceTable:
        """Open a table in the database.

@@ -442,7 +485,12 @@ class LanceDBConnection(DBConnection):
        -------
        A LanceTable object representing the table.
        """
        return LanceTable.open(self, name, index_cache_size=index_cache_size)
        return LanceTable.open(
            self,
            name,
            storage_options=storage_options,
            index_cache_size=index_cache_size,
        )

    @override
    def drop_table(self, name: str, ignore_missing: bool = False):
@@ -455,13 +503,7 @@ class LanceDBConnection(DBConnection):
        ignore_missing: bool, default False
            If True, ignore if the table does not exist.
        """
        try:
            LOOP.run(self._conn.drop_table(name))
        except ValueError as e:
            if not ignore_missing:
                raise e
            if f"Table '{name}' was not found" not in str(e):
                raise e
        LOOP.run(self._conn.drop_table(name, ignore_missing=ignore_missing))

    @override
    def drop_database(self):
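Taken together, the new keyword arguments above allow, for example (a sketch; the storage option key is illustrative, see the storage guide for real keys):

import lancedb

db = lancedb.connect("./.lancedb")
# Per-table storage options now override connection-level ones.
tbl = db.open_table("my_table", storage_options={"timeout": "60s"})
# Dropping a missing table no longer needs a try/except in user code.
db.drop_table("not_a_table", ignore_missing=True)
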
@@ -524,6 +566,10 @@ class AsyncConnection(object):
        Any attempt to use the connection after it is closed will result in an error."""
        self._inner.close()

    @property
    def uri(self) -> str:
        return self._inner.uri

    async def table_names(
        self, *, start_after: Optional[str] = None, limit: Optional[int] = None
    ) -> Iterable[str]:
@@ -557,6 +603,7 @@ class AsyncConnection(object):
        fill_value: Optional[float] = None,
        storage_options: Optional[Dict[str, str]] = None,
        *,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
        data_storage_version: Optional[str] = None,
        use_legacy_format: Optional[bool] = None,
        enable_v2_manifest_paths: Optional[bool] = None,
@@ -601,7 +648,7 @@ class AsyncConnection(object):
            Additional options for the storage backend. Options already set on the
            connection will be inherited by the table, but can be overridden here.
            See available options at
            https://lancedb.github.io/lancedb/guides/storage/
            <https://lancedb.github.io/lancedb/guides/storage/>
        data_storage_version: optional, str, default "stable"
            The version of the data storage format to use. Newer versions are more
            efficient but require newer versions of lance to read. The default is
@@ -730,6 +777,17 @@ class AsyncConnection(object):
        """
        metadata = None

        if embedding_functions is not None:
            # If we passed in embedding functions explicitly
            # then we'll override any schema metadata that
            # may have been implicitly specified by the LanceModel schema
            registry = EmbeddingFunctionRegistry.get_instance()
            metadata = registry.get_table_metadata(embedding_functions)

        data, schema = sanitize_create_table(
            data, schema, metadata, on_bad_vectors, fill_value
        )

        # Defining defaults here and not in function prototype. In the future
        # these defaults will move into rust so better to keep them as None.
        if on_bad_vectors is None:
@@ -791,7 +849,7 @@ class AsyncConnection(object):
            Additional options for the storage backend. Options already set on the
            connection will be inherited by the table, but can be overridden here.
            See available options at
            https://lancedb.github.io/lancedb/guides/storage/
            <https://lancedb.github.io/lancedb/guides/storage/>
        index_cache_size: int, default 256
            Set the size of the index cache, specified as a number of entries

@@ -822,15 +880,23 @@ class AsyncConnection(object):
        """
        await self._inner.rename_table(old_name, new_name)

    async def drop_table(self, name: str):
    async def drop_table(self, name: str, *, ignore_missing: bool = False):
        """Drop a table from the database.

        Parameters
        ----------
        name: str
            The name of the table.
        ignore_missing: bool, default False
            If True, ignore if the table does not exist.
        """
        await self._inner.drop_table(name)
        try:
            await self._inner.drop_table(name)
        except ValueError as e:
            if not ignore_missing:
                raise e
            if f"Table '{name}' was not found" not in str(e):
                raise e

    async def drop_database(self):
        """

@@ -1,20 +1,10 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

"""Full text search index using tantivy-py"""

import os
from typing import List, Tuple
from typing import List, Tuple, Optional

import pyarrow as pa

@@ -31,7 +21,7 @@ from .table import LanceTable
def create_index(
    index_path: str,
    text_fields: List[str],
    ordering_fields: List[str] = None,
    ordering_fields: Optional[List[str]] = None,
    tokenizer_name: str = "default",
) -> tantivy.Index:
    """
@@ -75,8 +65,8 @@ def populate_index(
    index: tantivy.Index,
    table: LanceTable,
    fields: List[str],
    writer_heap_size: int = 1024 * 1024 * 1024,
    ordering_fields: List[str] = None,
    writer_heap_size: Optional[int] = None,
    ordering_fields: Optional[List[str]] = None,
) -> int:
    """
    Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@ def populate_index(
    """
    if ordering_fields is None:
        ordering_fields = []
    writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
    # first check the fields exist and are string or large string type
    nested = []

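A short usage sketch of the two helpers whose signatures changed above, assuming a LanceTable named tbl with a string column "text" and an integer column "rank" (the index path is arbitrary):

from lancedb.fts import create_index, populate_index

index = create_index("/tmp/fts_index", text_fields=["text"], ordering_fields=["rank"])
num_docs = populate_index(index, tbl, fields=["text"], ordering_fields=["rank"])
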
@@ -1,8 +1,6 @@
from typing import Optional
from dataclasses import dataclass
from typing import Literal, Optional

from ._lancedb import (
    Index as LanceDbIndex,
)
from ._lancedb import (
    IndexConfig,
)
@@ -29,6 +27,7 @@ lang_mapping = {
}


@dataclass
class BTree:
    """Describes a btree index configuration

@@ -50,10 +49,10 @@ class BTree:
    the block size may be added in the future.
    """

    def __init__(self):
        self._inner = LanceDbIndex.btree()
    pass


@dataclass
class Bitmap:
    """Describe a Bitmap index configuration.

@@ -73,10 +72,10 @@ class Bitmap:
    requires 128 / 8 * 1Bi bytes on disk.
    """

    def __init__(self):
        self._inner = LanceDbIndex.bitmap()
    pass


@dataclass
class LabelList:
    """Describe a LabelList index configuration.

@@ -87,41 +86,57 @@ class LabelList:
    For example, it works with `tags`, `categories`, `keywords`, etc.
    """

    def __init__(self):
        self._inner = LanceDbIndex.label_list()
    pass


@dataclass
class FTS:
    """Describe a FTS index configuration.

    `FTS` is a full-text search index that can be used on `String` columns

    For example, it works with `title`, `description`, `content`, etc.

    Attributes
    ----------
    with_position : bool, default True
        Whether to store the position of the token in the document. Setting this
        to False can reduce the size of the index and improve indexing speed,
        but it will disable support for phrase queries.
    base_tokenizer : str, default "simple"
        The base tokenizer to use for tokenization. Options are:
        - "simple": Splits text by whitespace and punctuation.
        - "whitespace": Split text by whitespace, but not punctuation.
        - "raw": No tokenization. The entire text is treated as a single token.
    language : str, default "English"
        The language to use for tokenization.
    max_token_length : int, default 40
        The maximum token length to index. Tokens longer than this length will be
        ignored.
    lower_case : bool, default True
        Whether to convert the token to lower case. This makes queries case-insensitive.
    stem : bool, default False
        Whether to stem the token. Stemming reduces words to their root form.
        For example, in English "running" and "runs" would both be reduced to "run".
    remove_stop_words : bool, default False
        Whether to remove stop words. Stop words are common words that are often
        removed from text before indexing. For example, in English "the" and "and".
    ascii_folding : bool, default False
        Whether to fold ASCII characters. This converts accented characters to
        their ASCII equivalent. For example, "café" would be converted to "cafe".
    """

    def __init__(
        self,
        with_position: bool = True,
        base_tokenizer: str = "simple",
        language: str = "English",
        max_token_length: Optional[int] = 40,
        lower_case: bool = True,
        stem: bool = False,
        remove_stop_words: bool = False,
        ascii_folding: bool = False,
    ):
        self._inner = LanceDbIndex.fts(
            with_position=with_position,
            base_tokenizer=base_tokenizer,
            language=language,
            max_token_length=max_token_length,
            lower_case=lower_case,
            stem=stem,
            remove_stop_words=remove_stop_words,
            ascii_folding=ascii_folding,
        )
    with_position: bool = True
    base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple"
    language: str = "English"
    max_token_length: Optional[int] = 40
    lower_case: bool = True
    stem: bool = False
    remove_stop_words: bool = False
    ascii_folding: bool = False

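With FTS now a plain dataclass, a config can be built directly and handed to index creation; a sketch against the async API, assuming an AsyncTable tbl with a "text" column:

from lancedb.index import FTS

config = FTS(with_position=False, stem=True, remove_stop_words=True)
await tbl.create_index("text", config=config)  # inside an async function
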
@dataclass
class HnswPq:
    """Describe a HNSW-PQ index configuration.

@@ -232,30 +247,17 @@ class HnswPq:
        search phase.
    """

    def __init__(
        self,
        *,
        distance_type: Optional[str] = None,
        num_partitions: Optional[int] = None,
        num_sub_vectors: Optional[int] = None,
        num_bits: Optional[int] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
    ):
        self._inner = LanceDbIndex.hnsw_pq(
            distance_type=distance_type,
            num_partitions=num_partitions,
            num_sub_vectors=num_sub_vectors,
            num_bits=num_bits,
            max_iterations=max_iterations,
            sample_rate=sample_rate,
            m=m,
            ef_construction=ef_construction,
        )
    distance_type: Literal["l2", "cosine", "dot"] = "l2"
    num_partitions: Optional[int] = None
    num_sub_vectors: Optional[int] = None
    num_bits: int = 8
    max_iterations: int = 50
    sample_rate: int = 256
    m: int = 20
    ef_construction: int = 300


@dataclass
class HnswSq:
    """Describe a HNSW-SQ index configuration.

@@ -345,26 +347,106 @@ class HnswSq:

    """

    def __init__(
        self,
        *,
        distance_type: Optional[str] = None,
        num_partitions: Optional[int] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
    ):
        self._inner = LanceDbIndex.hnsw_sq(
            distance_type=distance_type,
            num_partitions=num_partitions,
            max_iterations=max_iterations,
            sample_rate=sample_rate,
            m=m,
            ef_construction=ef_construction,
        )
    distance_type: Literal["l2", "cosine", "dot"] = "l2"
    num_partitions: Optional[int] = None
    max_iterations: int = 50
    sample_rate: int = 256
    m: int = 20
    ef_construction: int = 300


@dataclass
class IvfFlat:
    """Describes an IVF Flat Index

    This index stores raw vectors.
    These vectors are grouped into partitions of similar vectors.
    Each partition keeps track of a centroid which is
    the average value of all vectors in the group.

    Attributes
    ----------
    distance_type: str, default "L2"
        The distance metric used to train the index

        This is used when training the index to calculate the IVF partitions
        (vectors are grouped in partitions with similar vectors according to this
        distance type) and to calculate a subvector's code during quantization.

        The distance type used to train an index MUST match the distance type used
        to search the index. Failure to do so will yield inaccurate results.

        The following distance types are available:

        "l2" - Euclidean distance. This is a very common distance metric that
        accounts for both magnitude and direction when determining the distance
        between vectors. L2 distance has a range of [0, ∞).

        "cosine" - Cosine distance. Cosine distance is a distance metric
        calculated from the cosine similarity between two vectors. Cosine
        similarity is a measure of similarity between two non-zero vectors of an
        inner product space. It is defined to equal the cosine of the angle
        between them. Unlike L2, the cosine distance is not affected by the
        magnitude of the vectors. Cosine distance has a range of [0, 2].

        Note: the cosine distance is undefined when one (or both) of the vectors
        are all zeros (there is no direction). These vectors are invalid and may
        never be returned from a vector search.

        "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
        distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
        L2 norm is 1), then dot distance is equivalent to the cosine distance.

        "hamming" - Hamming distance. Hamming distance is a distance metric
        calculated as the number of positions at which the corresponding bits are
        different. Hamming distance has a range of [0, vector dimension].

    num_partitions: int, default sqrt(num_rows)
        The number of IVF partitions to create.

        This value should generally scale with the number of rows in the dataset.
        By default the number of partitions is the square root of the number of
        rows.

        If this value is too large then the first part of the search (picking the
        right partition) will be slow. If this value is too small then the second
        part of the search (searching within a partition) will be slow.

    max_iterations: int, default 50
        Max iteration to train kmeans.

        When training an IVF PQ index we use kmeans to calculate the partitions.
        This parameter controls how many iterations of kmeans to run.

        Increasing this might improve the quality of the index but in most cases
        these extra iterations have diminishing returns.

        The default value is 50.
    sample_rate: int, default 256
        The rate used to calculate the number of training vectors for kmeans.

        When an IVF PQ index is trained, we need to calculate partitions. These
        are groups of vectors that are similar to each other. To do this we use an
        algorithm called kmeans.

        Running kmeans on a large dataset can be slow. To speed this up we run
        kmeans on a random sample of the data. This parameter controls the size of
        the sample. The total number of vectors used to train the index is
        `sample_rate * num_partitions`.

        Increasing this value might improve the quality of the index but in most
        cases the default should be sufficient.

        The default value is 256.
    """

    distance_type: Literal["l2", "cosine", "dot", "hamming"] = "l2"
    num_partitions: Optional[int] = None
    max_iterations: int = 50
    sample_rate: int = 256


@dataclass
class IvfPq:
    """Describes an IVF PQ Index

@@ -387,120 +469,113 @@ class IvfPq:

    Note that training an IVF PQ index on a large dataset is a slow operation and
    currently is also a memory intensive operation.

    Attributes
    ----------
    distance_type: str, default "L2"
        The distance metric used to train the index

        This is used when training the index to calculate the IVF partitions
        (vectors are grouped in partitions with similar vectors according to this
        distance type) and to calculate a subvector's code during quantization.

        The distance type used to train an index MUST match the distance type used
        to search the index. Failure to do so will yield inaccurate results.

        The following distance types are available:

        "l2" - Euclidean distance. This is a very common distance metric that
        accounts for both magnitude and direction when determining the distance
        between vectors. L2 distance has a range of [0, ∞).

        "cosine" - Cosine distance. Cosine distance is a distance metric
        calculated from the cosine similarity between two vectors. Cosine
        similarity is a measure of similarity between two non-zero vectors of an
        inner product space. It is defined to equal the cosine of the angle
        between them. Unlike L2, the cosine distance is not affected by the
        magnitude of the vectors. Cosine distance has a range of [0, 2].

        Note: the cosine distance is undefined when one (or both) of the vectors
        are all zeros (there is no direction). These vectors are invalid and may
        never be returned from a vector search.

        "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
        distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
        L2 norm is 1), then dot distance is equivalent to the cosine distance.
    num_partitions: int, default sqrt(num_rows)
        The number of IVF partitions to create.

        This value should generally scale with the number of rows in the dataset.
        By default the number of partitions is the square root of the number of
        rows.

        If this value is too large then the first part of the search (picking the
        right partition) will be slow. If this value is too small then the second
        part of the search (searching within a partition) will be slow.
    num_sub_vectors: int, default is vector dimension / 16
        Number of sub-vectors of PQ.

        This value controls how much the vector is compressed during the
        quantization step. The more sub vectors there are the less the vector is
        compressed. The default is the dimension of the vector divided by 16. If
        the dimension is not evenly divisible by 16 we use the dimension divided by
        8.

        The above two cases are highly preferred. Having 8 or 16 values per
        subvector allows us to use efficient SIMD instructions.

        If the dimension is not divisible by 8 then we use 1 subvector. This is not
        ideal and will likely result in poor performance.
    num_bits: int, default 8
        Number of bits to encode each sub-vector.

        This value controls how much the sub-vectors are compressed. The more bits
        the more accurate the index but the slower search. The default is 8
        bits. Only 4 and 8 are supported.
    max_iterations: int, default 50
        Max iteration to train kmeans.

        When training an IVF PQ index we use kmeans to calculate the partitions.
        This parameter controls how many iterations of kmeans to run.

        Increasing this might improve the quality of the index but in most cases
        these extra iterations have diminishing returns.

        The default value is 50.
    sample_rate: int, default 256
        The rate used to calculate the number of training vectors for kmeans.

        When an IVF PQ index is trained, we need to calculate partitions. These
        are groups of vectors that are similar to each other. To do this we use an
        algorithm called kmeans.

        Running kmeans on a large dataset can be slow. To speed this up we run
        kmeans on a random sample of the data. This parameter controls the size of
        the sample. The total number of vectors used to train the index is
        `sample_rate * num_partitions`.

        Increasing this value might improve the quality of the index but in most
        cases the default should be sufficient.

        The default value is 256.
    """

    def __init__(
        self,
        *,
        distance_type: Optional[str] = None,
        num_partitions: Optional[int] = None,
        num_sub_vectors: Optional[int] = None,
        num_bits: Optional[int] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[int] = None,
    ):
        """
        Create an IVF PQ index config

        Parameters
        ----------
        distance_type: str, default "L2"
            The distance metric used to train the index

            This is used when training the index to calculate the IVF partitions
            (vectors are grouped in partitions with similar vectors according to this
            distance type) and to calculate a subvector's code during quantization.

            The distance type used to train an index MUST match the distance type used
            to search the index. Failure to do so will yield inaccurate results.

            The following distance types are available:

            "l2" - Euclidean distance. This is a very common distance metric that
            accounts for both magnitude and direction when determining the distance
            between vectors. L2 distance has a range of [0, ∞).

            "cosine" - Cosine distance. Cosine distance is a distance metric
            calculated from the cosine similarity between two vectors. Cosine
            similarity is a measure of similarity between two non-zero vectors of an
            inner product space. It is defined to equal the cosine of the angle
            between them. Unlike L2, the cosine distance is not affected by the
            magnitude of the vectors. Cosine distance has a range of [0, 2].

            Note: the cosine distance is undefined when one (or both) of the vectors
            are all zeros (there is no direction). These vectors are invalid and may
            never be returned from a vector search.

            "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
            distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
            L2 norm is 1), then dot distance is equivalent to the cosine distance.
        num_partitions: int, default sqrt(num_rows)
            The number of IVF partitions to create.

            This value should generally scale with the number of rows in the dataset.
            By default the number of partitions is the square root of the number of
            rows.

            If this value is too large then the first part of the search (picking the
            right partition) will be slow. If this value is too small then the second
            part of the search (searching within a partition) will be slow.
        num_sub_vectors: int, default is vector dimension / 16
            Number of sub-vectors of PQ.

            This value controls how much the vector is compressed during the
            quantization step. The more sub vectors there are the less the vector is
            compressed. The default is the dimension of the vector divided by 16. If
            the dimension is not evenly divisible by 16 we use the dimension divided by
            8.

            The above two cases are highly preferred. Having 8 or 16 values per
            subvector allows us to use efficient SIMD instructions.

            If the dimension is not divisible by 8 then we use 1 subvector. This is not
            ideal and will likely result in poor performance.
        num_bits: int, default 8
            Number of bits to encode each sub-vector.

            This value controls how much the sub-vectors are compressed. The more bits
            the more accurate the index but the slower search. The default is 8
            bits. Only 4 and 8 are supported.
        max_iterations: int, default 50
            Max iteration to train kmeans.

            When training an IVF PQ index we use kmeans to calculate the partitions.
            This parameter controls how many iterations of kmeans to run.

            Increasing this might improve the quality of the index but in most cases
            these extra iterations have diminishing returns.

            The default value is 50.
        sample_rate: int, default 256
            The rate used to calculate the number of training vectors for kmeans.

            When an IVF PQ index is trained, we need to calculate partitions. These
            are groups of vectors that are similar to each other. To do this we use an
            algorithm called kmeans.

            Running kmeans on a large dataset can be slow. To speed this up we run
            kmeans on a random sample of the data. This parameter controls the size of
            the sample. The total number of vectors used to train the index is
            `sample_rate * num_partitions`.

            Increasing this value might improve the quality of the index but in most
            cases the default should be sufficient.

            The default value is 256.
        """
        if distance_type is not None:
            distance_type = distance_type.lower()
        self._inner = LanceDbIndex.ivf_pq(
            distance_type=distance_type,
            num_partitions=num_partitions,
            num_sub_vectors=num_sub_vectors,
            num_bits=num_bits,
            max_iterations=max_iterations,
            sample_rate=sample_rate,
        )
    distance_type: Literal["l2", "cosine", "dot"] = "l2"
    num_partitions: Optional[int] = None
    num_sub_vectors: Optional[int] = None
    num_bits: int = 8
    max_iterations: int = 50
    sample_rate: int = 256


__all__ = ["BTree", "IvfPq", "IndexConfig"]
__all__ = [
    "BTree",
    "IvfPq",
    "IvfFlat",
    "HnswPq",
    "HnswSq",
    "IndexConfig",
    "FTS",
    "Bitmap",
    "LabelList",
]

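The same dataclass style applies to the vector index configs; a sketch against the async API (the column name and parameter values are illustrative):

from lancedb.index import IvfPq

config = IvfPq(distance_type="cosine", num_partitions=256, num_sub_vectors=16)
await tbl.create_index("vector", config=config)  # inside an async function
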
@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
    # e.g. `{"nprobes": "10", "refine_factor": "10"}`
    nprobes: int = 10

    lower_bound: Optional[float] = None
    upper_bound: Optional[float] = None

    # Refine factor.
    refine_factor: Optional[int] = None

@@ -126,6 +129,9 @@ class Query(pydantic.BaseModel):

    ef: Optional[int] = None

    # Default is true. Set to false to enforce a brute force search.
    use_index: bool = True


class LanceQueryBuilder(ABC):
    """An abstract query builder. Subclasses are defined for vector search,
@@ -253,6 +259,7 @@ class LanceQueryBuilder(ABC):
        self._vector = None
        self._text = None
        self._ef = None
        self._use_index = True

    @deprecation.deprecated(
        deprecated_in="0.3.1",
@@ -511,6 +518,7 @@ class LanceQueryBuilder(ABC):
                "metric": self._metric,
                "nprobes": self._nprobes,
                "refine_factor": self._refine_factor,
                "use_index": self._use_index,
            },
            prefilter=self._prefilter,
            filter=self._str_query,
@@ -599,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        self._query = query
        self._metric = "L2"
        self._nprobes = 20
        self._lower_bound = None
        self._upper_bound = None
        self._refine_factor = None
        self._vector_column = vector_column
        self._prefilter = False
@@ -644,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        self._nprobes = nprobes
        return self

    def distance_range(
        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
    ) -> LanceVectorQueryBuilder:
        """Set the distance range to use.

        Only rows with distances within range [lower_bound, upper_bound)
        will be returned.

        Parameters
        ----------
        lower_bound: Optional[float]
            The lower bound of the distance range.
        upper_bound: Optional[float]
            The upper bound of the distance range.

        Returns
        -------
        LanceVectorQueryBuilder
            The LanceQueryBuilder object.
        """
        self._lower_bound = lower_bound
        self._upper_bound = upper_bound
        return self

    def ef(self, ef: int) -> LanceVectorQueryBuilder:
        """Set the number of candidates to consider during search.

@@ -723,12 +757,15 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
            metric=self._metric,
            columns=self._columns,
            nprobes=self._nprobes,
            lower_bound=self._lower_bound,
            upper_bound=self._upper_bound,
            refine_factor=self._refine_factor,
            vector_column=self._vector_column,
            with_row_id=self._with_row_id,
            offset=self._offset,
            fast_search=self._fast_search,
            ef=self._ef,
            use_index=self._use_index,
        )
        result_set = self._table._execute_query(query, batch_size)
        if self._reranker is not None:

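A sketch of the new distance_range filter on the sync builder (the query vector and bounds are illustrative):

# Keep only hits with distance in [0.0, 0.5) from the query vector.
results = (
    tbl.search([0.1, 0.2])
    .distance_range(lower_bound=0.0, upper_bound=0.5)
    .limit(10)
    .to_pandas()
)
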
@@ -802,6 +839,24 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        self._str_query = query_string if query_string is not None else self._str_query
        return self

    def bypass_vector_index(self) -> LanceVectorQueryBuilder:
        """
        If this is called then any vector index is skipped

        An exhaustive (flat) search will be performed. The query vector will
        be compared to every vector in the table. At high scales this can be
        expensive. However, this is often still useful. For example, skipping
        the vector index can give you ground truth results which you can use to
        calculate your recall to select an appropriate value for nprobes.

        Returns
        -------
        LanceVectorQueryBuilder
            The LanceVectorQueryBuilder object.
        """
        self._use_index = False
        return self


class LanceFtsQueryBuilder(LanceQueryBuilder):
    """A builder for full text search for LanceDB."""
@@ -1108,6 +1163,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
            self._vector_query.refine_factor(self._refine_factor)
        if self._ef:
            self._vector_query.ef(self._ef)
        if not self._use_index:
            self._vector_query.bypass_vector_index()

        with ThreadPoolExecutor() as executor:
            fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
@@ -1258,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
        self._nprobes = nprobes
        return self

    def distance_range(
        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
    ) -> LanceHybridQueryBuilder:
        """
        Set the distance range to use.

        Only rows with distances within range [lower_bound, upper_bound)
        will be returned.

        Parameters
        ----------
        lower_bound: Optional[float]
            The lower bound of the distance range.
        upper_bound: Optional[float]
            The upper bound of the distance range.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        self._lower_bound = lower_bound
        self._upper_bound = upper_bound
        return self

    def ef(self, ef: int) -> LanceHybridQueryBuilder:
        """
        Set the number of candidates to consider during search.
@@ -1323,6 +1405,24 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
        self._text = text
        return self

    def bypass_vector_index(self) -> LanceHybridQueryBuilder:
        """
        If this is called then any vector index is skipped

        An exhaustive (flat) search will be performed. The query vector will
        be compared to every vector in the table. At high scales this can be
        expensive. However, this is often still useful. For example, skipping
        the vector index can give you ground truth results which you can use to
        calculate your recall to select an appropriate value for nprobes.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        self._use_index = False
        return self


class AsyncQueryBase(object):
    def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):
@@ -1811,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
        self._inner.nprobes(nprobes)
        return self

    def distance_range(
        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
    ) -> AsyncVectorQuery:
        """Set the distance range to use.

        Only rows with distances within range [lower_bound, upper_bound)
        will be returned.

        Parameters
        ----------
        lower_bound: Optional[float]
            The lower bound of the distance range.
        upper_bound: Optional[float]
            The upper bound of the distance range.

        Returns
        -------
        AsyncVectorQuery
            The AsyncVectorQuery object.
        """
        self._inner.distance_range(lower_bound, upper_bound)
        return self

    def ef(self, ef: int) -> AsyncVectorQuery:
        """
        Set the number of candidates to consider during search

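The recall workflow that the bypass_vector_index docstring describes might look like this (a sketch; the query vector and limit are illustrative):

query = [0.1, 0.2]
approx = tbl.search(query).with_row_id(True).limit(10).to_arrow()
exact = tbl.search(query).with_row_id(True).bypass_vector_index().limit(10).to_arrow()
# Recall@10: fraction of exhaustive-search hits the index also returned.
recall = len(
    set(approx["_rowid"].to_pylist()) & set(exact["_rowid"].to_pylist())
) / 10
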
@@ -121,7 +121,13 @@ class RemoteDBConnection(DBConnection):
        return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit))

    @override
    def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
    def open_table(
        self,
        name: str,
        *,
        storage_options: Optional[Dict[str, str]] = None,
        index_cache_size: Optional[int] = None,
    ) -> Table:
        """Open a Lance Table in the database.

        Parameters

@@ -1,22 +1,15 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from datetime import timedelta
import logging
from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal
import warnings

from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
from lancedb._lancedb import IndexConfig
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa

@@ -25,7 +18,7 @@ from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry

from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import AsyncTable, Query, Table
from ..table import AsyncTable, IndexStatistics, Query, Table


class RemoteTable(Table):
@@ -62,7 +55,7 @@ class RemoteTable(Table):
        return LOOP.run(self._table.version())

    @cached_property
    def embedding_functions(self) -> dict:
    def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
        """
        Get the embedding functions for the table

@@ -88,17 +81,17 @@ class RemoteTable(Table):
        """to_pandas() is not yet supported on LanceDB cloud."""
        return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")

    def checkout(self, version):
    def checkout(self, version: int):
        return LOOP.run(self._table.checkout(version))

    def checkout_latest(self):
        return LOOP.run(self._table.checkout_latest())

    def list_indices(self):
    def list_indices(self) -> Iterable[IndexConfig]:
        """List all the indices on the table"""
        return LOOP.run(self._table.list_indices())

    def index_stats(self, index_uuid: str):
    def index_stats(self, index_uuid: str) -> Optional[IndexStatistics]:
        """List all the stats of a specified index"""
        return LOOP.run(self._table.index_stats(index_uuid))

@@ -232,10 +225,12 @@ class RemoteTable(Table):
            config = HnswPq(distance_type=metric)
        elif index_type == "IVF_HNSW_SQ":
            config = HnswSq(distance_type=metric)
        elif index_type == "IVF_FLAT":
            config = IvfFlat(distance_type=metric)
        else:
            raise ValueError(
                f"Unknown vector index type: {index_type}. Valid options are"
                " 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
                " 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
            )

        LOOP.run(self._table.create_index(vector_column_name, config=config))

@@ -479,16 +474,28 @@ class RemoteTable(Table):
        )

    def cleanup_old_versions(self, *_):
        """cleanup_old_versions() is not supported on the LanceDB cloud"""
        raise NotImplementedError(
            "cleanup_old_versions() is not supported on the LanceDB cloud"
        """
        cleanup_old_versions() is a no-op on LanceDB Cloud.

        Tables are automatically cleaned up and optimized.
        """
        warnings.warn(
            "cleanup_old_versions() is a no-op on LanceDB Cloud. "
            "Tables are automatically cleaned up and optimized."
        )
        pass

    def compact_files(self, *_):
        """compact_files() is not supported on the LanceDB cloud"""
        raise NotImplementedError(
            "compact_files() is not supported on the LanceDB cloud"
        """
        compact_files() is a no-op on LanceDB Cloud.

        Tables are automatically compacted and optimized.
        """
        warnings.warn(
            "compact_files() is a no-op on LanceDB Cloud. "
            "Tables are automatically compacted and optimized."
        )
        pass

    def optimize(
        self,
@@ -496,12 +503,16 @@ class RemoteTable(Table):
        cleanup_older_than: Optional[timedelta] = None,
        delete_unverified: bool = False,
    ):
        """optimize() is not supported on the LanceDB cloud.
        Indices are optimized automatically."""
        raise NotImplementedError(
            "optimize() is not supported on the LanceDB cloud. "
        """
        optimize() is a no-op on LanceDB Cloud.

        Indices are optimized automatically.
        """
        warnings.warn(
            "optimize() is a no-op on LanceDB Cloud. "
            "Indices are optimized automatically."
        )
        pass

    def count_rows(self, filter: Optional[str] = None) -> int:
        return LOOP.run(self._table.count_rows(filter))
@@ -515,6 +526,16 @@ class RemoteTable(Table):
    def drop_columns(self, columns: Iterable[str]):
        return LOOP.run(self._table.drop_columns(columns))

    def uses_v2_manifest_paths(self) -> bool:
        raise NotImplementedError(
            "uses_v2_manifest_paths() is not supported on the LanceDB Cloud"
        )

    def migrate_v2_manifest_paths(self):
        raise NotImplementedError(
            "migrate_v2_manifest_paths() is not supported on the LanceDB Cloud"
        )


def add_index(tbl: pa.Table, i: int) -> pa.Table:
    return tbl.add_column(

File diff suppressed because it is too large
@@ -314,3 +314,15 @@ def deprecated(func):
def validate_table_name(name: str):
    """Verify the table name is valid."""
    native_validate_table_name(name)


def add_note(base_exception: BaseException, note: str):
    if hasattr(base_exception, "add_note"):
        base_exception.add_note(note)
    elif isinstance(base_exception.args[0], str):
        base_exception.args = (
            base_exception.args[0] + "\n" + note,
            *base_exception.args[1:],
        )
    else:
        raise ValueError("Cannot add note to exception")

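A sketch of how add_note behaves on either side of Python 3.11 (the messages are illustrative):

err = ValueError("Table 'foo' was not found")
add_note(err, "Set ignore_missing=True to suppress this error")
# On Python >= 3.11 the note is attached via BaseException.add_note();
# on older versions it is appended to the first string in err.args.
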
python/python/tests/conftest.py (new file)
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from datetime import timedelta
from lancedb.db import AsyncConnection, DBConnection
import lancedb
import pytest
import pytest_asyncio


# Use an in-memory database for most tests.
@pytest.fixture
def mem_db() -> DBConnection:
    return lancedb.connect("memory://")


# Use a temporary directory when we need to inspect the database files.
@pytest.fixture
def tmp_db(tmp_path) -> DBConnection:
    return lancedb.connect(tmp_path)


@pytest_asyncio.fixture
async def mem_db_async() -> AsyncConnection:
    return await lancedb.connect_async("memory://")


@pytest_asyncio.fixture
async def tmp_db_async(tmp_path) -> AsyncConnection:
    return await lancedb.connect_async(
        tmp_path, read_consistency_interval=timedelta(seconds=0)
    )

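A test written against these fixtures might look like this (the table name and data are illustrative):

def test_count_rows(mem_db: DBConnection):
    tbl = mem_db.create_table("t", data=[{"vector": [1.0, 2.0]}, {"vector": [3.0, 4.0]}])
    assert tbl.count_rows() == 2
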
44
python/python/tests/docs/test_binary_vector.py
Normal file
44
python/python/tests/docs/test_binary_vector.py
Normal file
@@ -0,0 +1,44 @@
import shutil

# --8<-- [start:imports]
import lancedb
import numpy as np
import pytest

# --8<-- [end:imports]

shutil.rmtree("data/binary_lancedb", ignore_errors=True)


def test_binary_vector():
    # --8<-- [start:sync_binary_vector]
    db = lancedb.connect("data/binary_lancedb")
    data = [
        {
            "id": i,
            "vector": np.random.randint(0, 256, size=16),
        }
        for i in range(1024)
    ]
    tbl = db.create_table("my_binary_vectors", data=data)
    query = np.random.randint(0, 256, size=16)
    tbl.search(query).to_arrow()
    # --8<-- [end:sync_binary_vector]
    db.drop_table("my_binary_vectors")


@pytest.mark.asyncio
async def test_binary_vector_async():
    # --8<-- [start:async_binary_vector]
    db = await lancedb.connect_async("data/binary_lancedb")
    data = [
        {
            "id": i,
            "vector": np.random.randint(0, 256, size=16),
        }
        for i in range(1024)
    ]
    tbl = await db.create_table("my_binary_vectors", data=data)
    query = np.random.randint(0, 256, size=16)
    await tbl.query().nearest_to(query).to_arrow()
    # --8<-- [end:async_binary_vector]
    await db.drop_table("my_binary_vectors")
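A note on the data above: binary vectors are stored as `uint8` lists, with each byte conventionally packing 8 bits, so 16 bytes can represent a 128-dimensional binary vector. If you start from a 0/1 array, `numpy.packbits` produces the packed form (a sketch, not part of the diff):

import numpy as np

bits = np.random.randint(0, 2, size=128)  # a 128-dim binary vector
packed = np.packbits(bits)                # 16 uint8 values, 8 bits per byte
assert packed.shape == (16,)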
@@ -98,7 +98,7 @@ def test_ingest_pd(tmp_path):
    assert db.open_table("test").name == db["test"].name


def test_ingest_iterator(tmp_path):
def test_ingest_iterator(mem_db: lancedb.DBConnection):
    class PydanticSchema(LanceModel):
        vector: Vector(2)
        item: str
@@ -156,8 +156,7 @@ def test_ingest_iterator(tmp_path):
    ]

    def run_tests(schema):
        db = lancedb.connect(tmp_path)
        tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
        tbl = mem_db.create_table("table2", make_batches(), schema=schema)
        tbl.to_pandas()
        assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
        assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
@@ -165,15 +164,14 @@ def test_ingest_iterator(tmp_path):
        tbl.add(make_batches())
        assert tbl_len == 50
        assert len(tbl) == tbl_len * 2
        assert len(tbl.list_versions()) == 3
        db.drop_database()
        assert len(tbl.list_versions()) == 2
        mem_db.drop_database()

    run_tests(arrow_schema)
    run_tests(PydanticSchema)


def test_table_names(tmp_path):
    db = lancedb.connect(tmp_path)
def test_table_names(tmp_db: lancedb.DBConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -181,10 +179,10 @@ def test_table_names(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test2", data=data)
    db.create_table("test1", data=data)
    db.create_table("test3", data=data)
    assert db.table_names() == ["test1", "test2", "test3"]
    tmp_db.create_table("test2", data=data)
    tmp_db.create_table("test1", data=data)
    tmp_db.create_table("test3", data=data)
    assert tmp_db.table_names() == ["test1", "test2", "test3"]


@pytest.mark.asyncio
@@ -209,8 +207,7 @@ async def test_table_names_async(tmp_path):
    assert await db.table_names(start_after="test1") == ["test2", "test3"]


def test_create_mode(tmp_path):
    db = lancedb.connect(tmp_path)
def test_create_mode(tmp_db: lancedb.DBConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -218,10 +215,10 @@ def test_create_mode(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test", data=data)
    tmp_db.create_table("test", data=data)

    with pytest.raises(Exception):
        db.create_table("test", data=data)
        tmp_db.create_table("test", data=data)

    new_data = pd.DataFrame(
        {
@@ -230,13 +227,11 @@ def test_create_mode(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    tbl = db.create_table("test", data=new_data, mode="overwrite")
    tbl = tmp_db.create_table("test", data=new_data, mode="overwrite")
    assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]


def test_create_table_from_iterator(tmp_path):
    db = lancedb.connect(tmp_path)

def test_create_table_from_iterator(mem_db: lancedb.DBConnection):
    def gen_data():
        for _ in range(10):
            yield pa.RecordBatch.from_arrays(
@@ -248,14 +243,12 @@ def test_create_table_from_iterator(tmp_path):
                ["vector", "item", "price"],
            )

    table = db.create_table("test", data=gen_data())
    table = mem_db.create_table("test", data=gen_data())
    assert table.count_rows() == 10


@pytest.mark.asyncio
async def test_create_table_from_iterator_async(tmp_path):
    db = await lancedb.connect_async(tmp_path)

async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConnection):
    def gen_data():
        for _ in range(10):
            yield pa.RecordBatch.from_arrays(
@@ -267,12 +260,11 @@ async def test_create_table_from_iterator_async(tmp_path):
                ["vector", "item", "price"],
            )

    table = await db.create_table("test", data=gen_data())
    table = await mem_db_async.create_table("test", data=gen_data())
    assert await table.count_rows() == 10


def test_create_exist_ok(tmp_path):
    db = lancedb.connect(tmp_path)
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -280,13 +272,13 @@ def test_create_exist_ok(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    tbl = db.create_table("test", data=data)
    tbl = tmp_db.create_table("test", data=data)

    with pytest.raises(OSError):
        db.create_table("test", data=data)
    with pytest.raises(ValueError):
        tmp_db.create_table("test", data=data)

    # open the table but don't add more rows
    tbl2 = db.create_table("test", data=data, exist_ok=True)
    tbl2 = tmp_db.create_table("test", data=data, exist_ok=True)
    assert tbl.name == tbl2.name
    assert tbl.schema == tbl2.schema
    assert len(tbl) == len(tbl2)
@@ -298,7 +290,7 @@ def test_create_exist_ok(tmp_path):
            pa.field("price", pa.float64()),
        ]
    )
    tbl3 = db.create_table("test", schema=schema, exist_ok=True)
    tbl3 = tmp_db.create_table("test", schema=schema, exist_ok=True)
    assert tbl3.schema == schema

    bad_schema = pa.schema(
@@ -310,7 +302,7 @@ def test_create_exist_ok(tmp_path):
        ]
    )
    with pytest.raises(ValueError):
        db.create_table("test", schema=bad_schema, exist_ok=True)
        tmp_db.create_table("test", schema=bad_schema, exist_ok=True)


@pytest.mark.asyncio
@@ -325,26 +317,24 @@ async def test_connect(tmp_path):


@pytest.mark.asyncio
async def test_close(tmp_path):
    db = await lancedb.connect_async(tmp_path)
    assert db.is_open()
    db.close()
    assert not db.is_open()
async def test_close(mem_db_async: lancedb.AsyncConnection):
    assert mem_db_async.is_open()
    mem_db_async.close()
    assert not mem_db_async.is_open()

    with pytest.raises(RuntimeError, match="is closed"):
        await db.table_names()
        await mem_db_async.table_names()


@pytest.mark.asyncio
async def test_context_manager(tmp_path):
    with await lancedb.connect_async(tmp_path) as db:
async def test_context_manager():
    with await lancedb.connect_async("memory://") as db:
        assert db.is_open()
    assert not db.is_open()


@pytest.mark.asyncio
async def test_create_mode_async(tmp_path):
    db = await lancedb.connect_async(tmp_path)
async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -352,10 +342,10 @@ async def test_create_mode_async(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    await db.create_table("test", data=data)
    await tmp_db_async.create_table("test", data=data)

    with pytest.raises(ValueError, match="already exists"):
        await db.create_table("test", data=data)
        await tmp_db_async.create_table("test", data=data)

    new_data = pd.DataFrame(
        {
@@ -364,15 +354,14 @@ async def test_create_mode_async(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    _tbl = await db.create_table("test", data=new_data, mode="overwrite")
    _tbl = await tmp_db_async.create_table("test", data=new_data, mode="overwrite")

    # MIGRATION: to_pandas() is not available in async
    # assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]


@pytest.mark.asyncio
async def test_create_exist_ok_async(tmp_path):
    db = await lancedb.connect_async(tmp_path)
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -380,13 +369,13 @@ async def test_create_exist_ok_async(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    tbl = await db.create_table("test", data=data)
    tbl = await tmp_db_async.create_table("test", data=data)

    with pytest.raises(ValueError, match="already exists"):
        await db.create_table("test", data=data)
        await tmp_db_async.create_table("test", data=data)

    # open the table but don't add more rows
    tbl2 = await db.create_table("test", data=data, exist_ok=True)
    tbl2 = await tmp_db_async.create_table("test", data=data, exist_ok=True)
    assert tbl.name == tbl2.name
    assert await tbl.schema() == await tbl2.schema()

@@ -397,7 +386,7 @@ async def test_create_exist_ok_async(tmp_path):
            pa.field("price", pa.float64()),
        ]
    )
    tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
    tbl3 = await tmp_db_async.create_table("test", schema=schema, exist_ok=True)
    assert await tbl3.schema() == schema

    # Migration: When creating a table, but the table already exists, but
@@ -448,13 +437,12 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
    assert re.match(r"\d{20}\.manifest", manifest)


def test_open_table_sync(tmp_path):
    db = lancedb.connect(tmp_path)
    db.create_table("test", data=[{"id": 0}])
    assert db.open_table("test").count_rows() == 1
    assert db.open_table("test", index_cache_size=0).count_rows() == 1
    with pytest.raises(FileNotFoundError, match="does not exist"):
        db.open_table("does_not_exist")
def test_open_table_sync(tmp_db: lancedb.DBConnection):
    tmp_db.create_table("test", data=[{"id": 0}])
    assert tmp_db.open_table("test").count_rows() == 1
    assert tmp_db.open_table("test", index_cache_size=0).count_rows() == 1
    with pytest.raises(ValueError, match="Table 'does_not_exist' was not found"):
        tmp_db.open_table("does_not_exist")


@pytest.mark.asyncio
@@ -494,8 +482,7 @@ async def test_open_table(tmp_path):
        await db.open_table("does_not_exist")


def test_delete_table(tmp_path):
    db = lancedb.connect(tmp_path)
def test_delete_table(tmp_db: lancedb.DBConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -503,26 +490,51 @@ def test_delete_table(tmp_path):
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test", data=data)
    tmp_db.create_table("test", data=data)

    with pytest.raises(Exception):
        db.create_table("test", data=data)
        tmp_db.create_table("test", data=data)

    assert db.table_names() == ["test"]
    assert tmp_db.table_names() == ["test"]

    db.drop_table("test")
    assert db.table_names() == []
    tmp_db.drop_table("test")
    assert tmp_db.table_names() == []

    db.create_table("test", data=data)
    assert db.table_names() == ["test"]
    tmp_db.create_table("test", data=data)
    assert tmp_db.table_names() == ["test"]

    # dropping a table that does not exist should pass
    # if ignore_missing=True
    db.drop_table("does_not_exist", ignore_missing=True)
    tmp_db.drop_table("does_not_exist", ignore_missing=True)


def test_drop_database(tmp_path):
    db = lancedb.connect(tmp_path)
@pytest.mark.asyncio
async def test_delete_table_async(tmp_db: lancedb.DBConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )

    tmp_db.create_table("test", data=data)

    with pytest.raises(Exception):
        tmp_db.create_table("test", data=data)

    assert tmp_db.table_names() == ["test"]

    tmp_db.drop_table("test")
    assert tmp_db.table_names() == []

    tmp_db.create_table("test", data=data)
    assert tmp_db.table_names() == ["test"]

    tmp_db.drop_table("does_not_exist", ignore_missing=True)


def test_drop_database(tmp_db: lancedb.DBConnection):
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -537,51 +549,50 @@ def test_drop_database(tmp_path):
            "price": [12.0, 17.0],
        }
    )
    db.create_table("test", data=data)
    tmp_db.create_table("test", data=data)
    with pytest.raises(Exception):
        db.create_table("test", data=data)
        tmp_db.create_table("test", data=data)

    assert db.table_names() == ["test"]
    assert tmp_db.table_names() == ["test"]

    db.create_table("new_test", data=new_data)
    db.drop_database()
    assert db.table_names() == []
    tmp_db.create_table("new_test", data=new_data)
    tmp_db.drop_database()
    assert tmp_db.table_names() == []

    # it should pass when no tables are present
    db.create_table("test", data=new_data)
    db.drop_table("test")
    assert db.table_names() == []
    db.drop_database()
    assert db.table_names() == []
    tmp_db.create_table("test", data=new_data)
    tmp_db.drop_table("test")
    assert tmp_db.table_names() == []
    tmp_db.drop_database()
    assert tmp_db.table_names() == []

    # creating an empty database with schema
    schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
    db.create_table("empty_table", schema=schema)
    tmp_db.create_table("empty_table", schema=schema)
    # dropping an empty database should pass
    db.drop_database()
    assert db.table_names() == []
    tmp_db.drop_database()
    assert tmp_db.table_names() == []


def test_empty_or_nonexistent_table(tmp_path):
    db = lancedb.connect(tmp_path)
def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection):
    with pytest.raises(Exception):
        db.create_table("test_with_no_data")
        mem_db.create_table("test_with_no_data")

    with pytest.raises(Exception):
        db.open_table("does_not_exist")
        mem_db.open_table("does_not_exist")

    schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
    test = db.create_table("test", schema=schema)
    test = mem_db.create_table("test", schema=schema)

    class TestModel(LanceModel):
        a: int

    test2 = db.create_table("test2", schema=TestModel)
    test2 = mem_db.create_table("test2", schema=TestModel)
    assert test.schema == test2.schema


@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
    def make_data():
        for i in range(10):
            yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
@@ -591,10 +602,8 @@ async def test_create_in_v2_mode(tmp_path):

    schema = pa.schema([pa.field("x", pa.int64())])

    db = await lancedb.connect_async(tmp_path)

    # Create table in v1 mode
    tbl = await db.create_table(
    tbl = await mem_db_async.create_table(
        "test", data=make_data(), schema=schema, data_storage_version="legacy"
    )

@@ -610,7 +619,7 @@ async def test_create_in_v2_mode(tmp_path):
    assert not await is_in_v2_mode(tbl)

    # Create table in v2 mode
    tbl = await db.create_table(
    tbl = await mem_db_async.create_table(
        "test_v2", data=make_data(), schema=schema, use_legacy_format=False
    )

@@ -622,7 +631,7 @@ async def test_create_in_v2_mode(tmp_path):
    assert await is_in_v2_mode(tbl)

    # Create empty table in v2 mode and add data
    tbl = await db.create_table(
    tbl = await mem_db_async.create_table(
        "test_empty_v2", data=None, schema=schema, use_legacy_format=False
    )
    await tbl.add(make_table())
@@ -630,7 +639,7 @@ async def test_create_in_v2_mode(tmp_path):
    assert await is_in_v2_mode(tbl)

    # Create empty table uses v1 mode by default
    tbl = await db.create_table(
    tbl = await mem_db_async.create_table(
        "test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
    )
    await tbl.add(make_table())
@@ -638,18 +647,17 @@ async def test_create_in_v2_mode(tmp_path):
    assert not await is_in_v2_mode(tbl)


def test_replace_index(tmp_path):
    db = lancedb.connect(uri=tmp_path)
    table = db.create_table(
def test_replace_index(mem_db: lancedb.DBConnection):
    table = mem_db.create_table(
        "test",
        [
            {"vector": np.random.rand(128), "item": "foo", "price": float(i)}
            for i in range(1000)
            {"vector": np.random.rand(32), "item": "foo", "price": float(i)}
            for i in range(512)
        ],
    )
    table.create_index(
        num_partitions=2,
        num_sub_vectors=4,
        num_sub_vectors=2,
    )

    with pytest.raises(Exception):
@@ -660,27 +668,26 @@ def test_replace_index(tmp_path):
        )

    table.create_index(
        num_partitions=2,
        num_sub_vectors=4,
        num_partitions=1,
        num_sub_vectors=2,
        replace=True,
        index_cache_size=10,
    )


def test_prefilter_with_index(tmp_path):
    db = lancedb.connect(uri=tmp_path)
def test_prefilter_with_index(mem_db: lancedb.DBConnection):
    data = [
        {"vector": np.random.rand(128), "item": "foo", "price": float(i)}
        for i in range(1000)
        {"vector": np.random.rand(32), "item": "foo", "price": float(i)}
        for i in range(512)
    ]
    sample_key = data[100]["vector"]
    table = db.create_table(
    table = mem_db.create_table(
        "test",
        data,
    )
    table.create_index(
        num_partitions=2,
        num_sub_vectors=4,
        num_sub_vectors=2,
    )
    table = (
        table.search(sample_key)
@@ -691,13 +698,34 @@ def test_prefilter_with_index(tmp_path):
    assert table.num_rows == 1


def test_create_table_with_invalid_names(tmp_path):
    db = lancedb.connect(uri=tmp_path)
def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection):
    data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
    with pytest.raises(ValueError):
        db.create_table("foo/bar", data)
        tmp_db.create_table("foo/bar", data)
    with pytest.raises(ValueError):
        db.create_table("foo bar", data)
        tmp_db.create_table("foo bar", data)
    with pytest.raises(ValueError):
        db.create_table("foo$$bar", data)
    db.create_table("foo.bar", data)
        tmp_db.create_table("foo$$bar", data)
    tmp_db.create_table("foo.bar", data)


def test_bypass_vector_index_sync(tmp_db: lancedb.DBConnection):
    data = [{"vector": np.random.rand(32)} for _ in range(512)]
    sample_key = data[100]["vector"]
    table = tmp_db.create_table(
        "test",
        data,
    )

    table.create_index(
        num_partitions=2,
        num_sub_vectors=2,
    )

    plan_with_index = table.search(sample_key).explain_plan(verbose=True)
    assert "ANN" in plan_with_index

    plan_without_index = (
        table.search(sample_key).bypass_vector_index().explain_plan(verbose=True)
    )
    assert "KNN" in plan_without_index
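The plan node names carry the meaning here: an indexed search shows an ANN (approximate) node, while `bypass_vector_index()` forces an exact, brute-force KNN scan. One practical use of the bypass is measuring index recall against exact results; a sketch under the same setup, assuming the table also has an "id" column to compare result sets on:

exact = table.search(sample_key).bypass_vector_index().limit(10).to_arrow()
approx = table.search(sample_key).limit(10).to_arrow()
# fraction of the true top-10 that the index recovered
recall = len(set(approx["id"].to_pylist()) & set(exact["id"].to_pylist())) / 10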
@@ -15,10 +15,12 @@ import random
from unittest import mock

import lancedb as ldb
from lancedb.db import DBConnection
from lancedb.index import FTS
import numpy as np
import pandas as pd
import pytest
from utils import exception_output

pytest.importorskip("lancedb.fts")
tantivy = pytest.importorskip("tantivy")
@@ -165,8 +167,24 @@ def test_search_index(tmp_path, table):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_search_fts(table, use_tantivy):
    table.create_fts_index("text", use_tantivy=use_tantivy)
    results = table.search("puppy").limit(5).to_list()
    results = table.search("puppy").select(["id", "text"]).limit(5).to_list()
    assert len(results) == 5
    assert len(results[0]) == 3  # id, text, _score


@pytest.mark.asyncio
async def test_fts_select_async(async_table):
    tbl = await async_table
    await tbl.create_index("text", config=FTS())
    results = (
        await tbl.query()
        .nearest_to_text("puppy")
        .select(["id", "text"])
        .limit(5)
        .to_list()
    )
    assert len(results) == 5
    assert len(results[0]) == 3  # id, text, _score


def test_search_fts_phrase_query(table):
@@ -458,3 +476,44 @@ def test_syntax(table):
    table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit(
        10
    ).to_list()


def test_language(mem_db: DBConnection):
    sentences = [
        "Il n'y a que trois routes qui traversent la ville.",
        "Je veux prendre la route vers l'est.",
        "Je te retrouve au café au bout de la route.",
    ]
    data = [{"text": s} for s in sentences]
    table = mem_db.create_table("test", data=data)

    with pytest.raises(ValueError) as e:
        table.create_fts_index("text", use_tantivy=False, language="klingon")

    assert exception_output(e) == (
        "ValueError: LanceDB does not support the requested language: 'klingon'\n"
        "Supported languages: Arabic, Danish, Dutch, English, Finnish, French, "
        "German, Greek, Hungarian, Italian, Norwegian, Portuguese, Romanian, "
        "Russian, Spanish, Swedish, Tamil, Turkish"
    )

    table.create_fts_index(
        "text",
        use_tantivy=False,
        language="French",
        stem=True,
        ascii_folding=True,
        remove_stop_words=True,
    )

    # Can match "routes" and "route" from the same stem
    results = table.search("route", query_type="fts").limit(5).to_list()
    assert len(results) == 3

    # Can find "café" without needing to type the accent
    results = table.search("cafe", query_type="fts").limit(5).to_list()
    assert len(results) == 1

    # Stop words -> no results
    results = table.search("la", query_type="fts").limit(5).to_list()
    assert len(results) == 0
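The tokenizer options exercised above compose: stemming maps "route"/"routes" to one stem, `ascii_folding` strips accents so "cafe" matches "café", and stop-word removal drops high-frequency words like "la". The same options should work on English text, e.g. (a sketch using the parameters from the test above):

table.create_fts_index(
    "text",
    use_tantivy=False,
    language="English",
    stem=True,               # "running" and "run" share a stem
    ascii_folding=True,      # "résumé" matches "resume"
    remove_stop_words=True,  # "the", "a", ... are dropped from the index
)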
@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq


@pytest_asyncio.fixture
@@ -42,6 +42,27 @@ async def some_table(db_async):
    )


@pytest_asyncio.fixture
async def binary_table(db_async):
    data = [
        {
            "id": i,
            "vector": [i] * 128,
        }
        for i in range(NROWS)
    ]
    return await db_async.create_table(
        "binary_table",
        data,
        schema=pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.uint8(), 128)),
            ]
        ),
    )


@pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable):
    # Can create
@@ -143,3 +164,27 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
    await some_table.create_index("vector", config=HnswSq(num_partitions=10))
    indices = await some_table.list_indices()
    assert len(indices) == 1


@pytest.mark.asyncio
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
    await binary_table.create_index(
        "vector", config=IvfFlat(distance_type="hamming", num_partitions=10)
    )
    indices = await binary_table.list_indices()
    assert len(indices) == 1
    assert indices[0].index_type == "IvfFlat"
    assert indices[0].columns == ["vector"]
    assert indices[0].name == "vector_idx"

    stats = await binary_table.index_stats("vector_idx")
    assert stats.index_type == "IVF_FLAT"
    assert stats.distance_type == "hamming"
    assert stats.num_indexed_rows == await binary_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1

    # the dataset contains vectors with all values from 0 to 255
    for v in range(256):
        res = await binary_table.query().nearest_to([v] * 128).to_arrow()
        assert res["id"][0].as_py() == v
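Binary vectors are searched with hamming distance, which counts the bit positions where two vectors differ; an identical query vector therefore has distance 0, which is why each `[v] * 128` query above returns row `v` first. As a worked example, 0b1010 and 0b0011 differ in the two high bits, so their distance is 2:

a, b = 0b1010, 0b0011
assert bin(a ^ b).count("1") == 2  # hamming distance = popcount of XOR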
@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
    assert rs["_rowid"].to_pylist() == [0, 1]


def test_distance_range(table: lancedb.table.Table):
    q = [0, 0]
    rs = table.search(q).to_arrow()
    dists = rs["_distance"].to_pylist()
    min_dist = dists[0]
    max_dist = dists[-1]

    res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
    assert len(res) == 0

    res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
    assert len(res) == 1
    assert res["_distance"].to_pylist() == [max_dist]

    res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
    assert len(res) == 1
    assert res["_distance"].to_pylist() == [min_dist]

    res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
    assert len(res) == 2
    assert res["_distance"].to_pylist() == [min_dist, max_dist]


@pytest.mark.asyncio
async def test_distance_range_async(table_async: AsyncTable):
    q = [0, 0]
    rs = await table_async.query().nearest_to(q).to_arrow()
    dists = rs["_distance"].to_pylist()
    min_dist = dists[0]
    max_dist = dists[-1]

    res = (
        await table_async.query()
        .nearest_to(q)
        .distance_range(upper_bound=min_dist)
        .to_arrow()
    )
    assert len(res) == 0

    res = (
        await table_async.query()
        .nearest_to(q)
        .distance_range(lower_bound=max_dist)
        .to_arrow()
    )
    assert len(res) == 1
    assert res["_distance"].to_pylist() == [max_dist]

    res = (
        await table_async.query()
        .nearest_to(q)
        .distance_range(upper_bound=max_dist)
        .to_arrow()
    )
    assert len(res) == 1
    assert res["_distance"].to_pylist() == [min_dist]

    res = (
        await table_async.query()
        .nearest_to(q)
        .distance_range(lower_bound=min_dist)
        .to_arrow()
    )
    assert len(res) == 2
    assert res["_distance"].to_pylist() == [min_dist, max_dist]


def test_vector_query_with_no_limit(table):
    with pytest.raises(ValueError):
        LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
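Both tests pin down the interval convention: `lower_bound` is inclusive and `upper_bound` is exclusive, so only rows with distance in [lower_bound, upper_bound) are returned. That is why `upper_bound=max_dist` excludes the farthest row while `lower_bound=max_dist` still includes it. Either bound may be omitted; a sketch with an arbitrary threshold:

# Keep only neighbors strictly closer than 0.5:
res = table.search(q).distance_range(upper_bound=0.5).to_arrow()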
@@ -306,6 +306,8 @@ def test_query_sync_minimal():
    "k": 10,
    "prefilter": False,
    "refine_factor": None,
    "lower_bound": None,
    "upper_bound": None,
    "ef": None,
    "vector": [1.0, 2.0, 3.0],
    "nprobes": 20,
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
    "refine_factor": 10,
    "vector": [1.0, 2.0, 3.0],
    "nprobes": 5,
    "lower_bound": None,
    "upper_bound": None,
    "ef": None,
    "filter": "id > 0",
    "columns": ["id", "name"],
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
    "refine_factor": None,
    "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    "nprobes": 20,
    "lower_bound": None,
    "upper_bound": None,
    "ef": None,
    "with_row_id": True,
    "version": None,
(File diff suppressed because it is too large)
11
python/python/tests/utils.py
Normal file
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pytest


def exception_output(e_info: pytest.ExceptionInfo):
    import traceback

    # skip traceback part, since it's not worth checking in tests
    lines = traceback.format_exception_only(e_info.type, e_info.value)
    return "".join(lines).strip()
@@ -58,6 +58,11 @@ impl Connection {
        self.inner.take();
    }

    #[getter]
    pub fn uri(&self) -> PyResult<String> {
        self.get_inner().map(|inner| inner.uri().to_string())
    }

    #[pyo3(signature = (start_after=None, limit=None))]
    pub fn table_names(
        self_: PyRef<'_, Self>,
@@ -12,224 +12,174 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Mutex;

use lancedb::index::scalar::FtsIndexBuilder;
use lancedb::{
    index::{
        scalar::BTreeIndexBuilder,
        vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
        Index as LanceDbIndex,
    },
    DistanceType,
use lancedb::index::vector::IvfFlatIndexBuilder;
use lancedb::index::{
    scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
    vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
    Index as LanceDbIndex,
};
use pyo3::{
    exceptions::{PyKeyError, PyRuntimeError, PyValueError},
    pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
    exceptions::{PyKeyError, PyValueError},
    intern, pyclass, pymethods,
    types::PyAnyMethods,
    Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python,
};

use crate::util::parse_distance_type;

#[pyclass]
pub struct Index {
    inner: Mutex<Option<LanceDbIndex>>,
}

impl Index {
    pub fn consume(&self) -> PyResult<LanceDbIndex> {
        self.inner
            .lock()
            .unwrap()
            .take()
            .ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
pub fn class_name<'a>(ob: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
    let full_name: &str = ob
        .getattr(intern!(ob.py(), "__class__"))?
        .getattr(intern!(ob.py(), "__name__"))?
        .extract()?;
    match full_name.rsplit_once('.') {
        Some((_, name)) => Ok(name),
        None => Ok(full_name),
    }
}
#[pymethods]
impl Index {
    #[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None, num_bits=None, max_iterations=None, sample_rate=None))]
    #[staticmethod]
    pub fn ivf_pq(
        distance_type: Option<String>,
        num_partitions: Option<u32>,
        num_sub_vectors: Option<u32>,
        num_bits: Option<u32>,
        max_iterations: Option<u32>,
        sample_rate: Option<u32>,
    ) -> PyResult<Self> {
        let mut ivf_pq_builder = IvfPqIndexBuilder::default();
        if let Some(distance_type) = distance_type {
            let distance_type = match distance_type.as_str() {
                "l2" => Ok(DistanceType::L2),
                "cosine" => Ok(DistanceType::Cosine),
                "dot" => Ok(DistanceType::Dot),
                _ => Err(PyValueError::new_err(format!(
                    "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
                    distance_type
                ))),
            }?;
            ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<LanceDbIndex> {
    if let Some(source) = source {
        match class_name(source)? {
            "BTree" => Ok(LanceDbIndex::BTree(BTreeIndexBuilder::default())),
            "Bitmap" => Ok(LanceDbIndex::Bitmap(Default::default())),
            "LabelList" => Ok(LanceDbIndex::LabelList(Default::default())),
            "FTS" => {
                let params = source.extract::<FtsParams>()?;
                let inner_opts = TokenizerConfig::default()
                    .base_tokenizer(params.base_tokenizer)
                    .language(&params.language)
                    .map_err(|_| PyValueError::new_err(format!("LanceDB does not support the requested language: '{}'", params.language)))?
                    .lower_case(params.lower_case)
                    .max_token_length(params.max_token_length)
                    .remove_stop_words(params.remove_stop_words)
                    .stem(params.stem)
                    .ascii_folding(params.ascii_folding);
                let mut opts = FtsIndexBuilder::default()
                    .with_position(params.with_position);
                opts.tokenizer_configs = inner_opts;
                Ok(LanceDbIndex::FTS(opts))
            },
            "IvfFlat" => {
                let params = source.extract::<IvfFlatParams>()?;
                let distance_type = parse_distance_type(params.distance_type)?;
                let mut ivf_flat_builder = IvfFlatIndexBuilder::default()
                    .distance_type(distance_type)
                    .max_iterations(params.max_iterations)
                    .sample_rate(params.sample_rate);
                if let Some(num_partitions) = params.num_partitions {
                    ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
                }
                Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
            },
            "IvfPq" => {
                let params = source.extract::<IvfPqParams>()?;
                let distance_type = parse_distance_type(params.distance_type)?;
                let mut ivf_pq_builder = IvfPqIndexBuilder::default()
                    .distance_type(distance_type)
                    .max_iterations(params.max_iterations)
                    .sample_rate(params.sample_rate)
                    .num_bits(params.num_bits);
                if let Some(num_partitions) = params.num_partitions {
                    ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
                }
                if let Some(num_sub_vectors) = params.num_sub_vectors {
                    ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
                }
                Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
            },
            "HnswPq" => {
                let params = source.extract::<IvfHnswPqParams>()?;
                let distance_type = parse_distance_type(params.distance_type)?;
                let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default()
                    .distance_type(distance_type)
                    .max_iterations(params.max_iterations)
                    .sample_rate(params.sample_rate)
                    .num_edges(params.m)
                    .ef_construction(params.ef_construction)
                    .num_bits(params.num_bits);
                if let Some(num_partitions) = params.num_partitions {
                    hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
                }
                if let Some(num_sub_vectors) = params.num_sub_vectors {
                    hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
                }
                Ok(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))
            },
            "HnswSq" => {
                let params = source.extract::<IvfHnswSqParams>()?;
                let distance_type = parse_distance_type(params.distance_type)?;
                let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default()
                    .distance_type(distance_type)
                    .max_iterations(params.max_iterations)
                    .sample_rate(params.sample_rate)
                    .num_edges(params.m)
                    .ef_construction(params.ef_construction);
                if let Some(num_partitions) = params.num_partitions {
                    hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
                }
                Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
            },
            not_supported => Err(PyValueError::new_err(format!(
                "Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfHnswPq, or IvfHnswSq",
                not_supported
            ))),
        }
        if let Some(num_partitions) = num_partitions {
            ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
        }
        if let Some(num_sub_vectors) = num_sub_vectors {
            ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
        }
        if let Some(num_bits) = num_bits {
            ivf_pq_builder = ivf_pq_builder.num_bits(num_bits);
        }
        if let Some(max_iterations) = max_iterations {
            ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
        }
        if let Some(sample_rate) = sample_rate {
            ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
        }
        Ok(Self {
            inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
        })
    } else {
        Ok(LanceDbIndex::Auto)
    }
}
    #[staticmethod]
    pub fn btree() -> PyResult<Self> {
        Ok(Self {
            inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
        })
    }
#[derive(FromPyObject)]
struct FtsParams {
    with_position: bool,
    base_tokenizer: String,
    language: String,
    max_token_length: Option<usize>,
    lower_case: bool,
    stem: bool,
    remove_stop_words: bool,
    ascii_folding: bool,
}

    #[staticmethod]
    pub fn bitmap() -> PyResult<Self> {
        Ok(Self {
            inner: Mutex::new(Some(LanceDbIndex::Bitmap(Default::default()))),
        })
    }
#[derive(FromPyObject)]
struct IvfFlatParams {
    distance_type: String,
    num_partitions: Option<u32>,
    max_iterations: u32,
    sample_rate: u32,
}

    #[staticmethod]
    pub fn label_list() -> PyResult<Self> {
        Ok(Self {
            inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))),
        })
    }
#[derive(FromPyObject)]
struct IvfPqParams {
    distance_type: String,
    num_partitions: Option<u32>,
    num_sub_vectors: Option<u32>,
    num_bits: u32,
    max_iterations: u32,
    sample_rate: u32,
}

    #[pyo3(signature = (with_position=None, base_tokenizer=None, language=None, max_token_length=None, lower_case=None, stem=None, remove_stop_words=None, ascii_folding=None))]
    #[allow(clippy::too_many_arguments)]
    #[staticmethod]
    pub fn fts(
        with_position: Option<bool>,
        base_tokenizer: Option<String>,
        language: Option<String>,
        max_token_length: Option<usize>,
        lower_case: Option<bool>,
        stem: Option<bool>,
        remove_stop_words: Option<bool>,
        ascii_folding: Option<bool>,
    ) -> Self {
        let mut opts = FtsIndexBuilder::default();
        if let Some(with_position) = with_position {
            opts = opts.with_position(with_position);
        }
        if let Some(base_tokenizer) = base_tokenizer {
            opts.tokenizer_configs = opts.tokenizer_configs.base_tokenizer(base_tokenizer);
        }
        if let Some(language) = language {
            opts.tokenizer_configs = opts.tokenizer_configs.language(&language).unwrap();
        }
        opts.tokenizer_configs = opts.tokenizer_configs.max_token_length(max_token_length);
        if let Some(lower_case) = lower_case {
            opts.tokenizer_configs = opts.tokenizer_configs.lower_case(lower_case);
        }
        if let Some(stem) = stem {
            opts.tokenizer_configs = opts.tokenizer_configs.stem(stem);
        }
        if let Some(remove_stop_words) = remove_stop_words {
            opts.tokenizer_configs = opts.tokenizer_configs.remove_stop_words(remove_stop_words);
        }
        if let Some(ascii_folding) = ascii_folding {
            opts.tokenizer_configs = opts.tokenizer_configs.ascii_folding(ascii_folding);
        }
        Self {
            inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
        }
    }
#[derive(FromPyObject)]
struct IvfHnswPqParams {
    distance_type: String,
    num_partitions: Option<u32>,
    num_sub_vectors: Option<u32>,
    num_bits: u32,
    max_iterations: u32,
    sample_rate: u32,
    m: u32,
    ef_construction: u32,
}

    #[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None, num_bits=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
    #[staticmethod]
    #[allow(clippy::too_many_arguments)]
    pub fn hnsw_pq(
        distance_type: Option<String>,
        num_partitions: Option<u32>,
        num_sub_vectors: Option<u32>,
        num_bits: Option<u32>,
        max_iterations: Option<u32>,
        sample_rate: Option<u32>,
        m: Option<u32>,
        ef_construction: Option<u32>,
    ) -> PyResult<Self> {
        let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
        if let Some(distance_type) = distance_type {
            let distance_type = parse_distance_type(distance_type)?;
            hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
        }
        if let Some(num_partitions) = num_partitions {
            hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
        }
        if let Some(num_sub_vectors) = num_sub_vectors {
            hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
        }
        if let Some(num_bits) = num_bits {
            hnsw_pq_builder = hnsw_pq_builder.num_bits(num_bits);
        }
        if let Some(max_iterations) = max_iterations {
            hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
        }
        if let Some(sample_rate) = sample_rate {
            hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
        }
        if let Some(m) = m {
            hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
        }
        if let Some(ef_construction) = ef_construction {
            hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
        }
        Ok(Self {
            inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
        })
    }

    #[pyo3(signature = (distance_type=None, num_partitions=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
    #[staticmethod]
    pub fn hnsw_sq(
        distance_type: Option<String>,
        num_partitions: Option<u32>,
        max_iterations: Option<u32>,
        sample_rate: Option<u32>,
        m: Option<u32>,
        ef_construction: Option<u32>,
    ) -> PyResult<Self> {
        let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
        if let Some(distance_type) = distance_type {
            let distance_type = parse_distance_type(distance_type)?;
            hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
        }
        if let Some(num_partitions) = num_partitions {
            hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
        }
        if let Some(max_iterations) = max_iterations {
            hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
        }
        if let Some(sample_rate) = sample_rate {
            hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
        }
        if let Some(m) = m {
            hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
        }
        if let Some(ef_construction) = ef_construction {
            hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
        }
        Ok(Self {
            inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
        })
    }
#[derive(FromPyObject)]
struct IvfHnswSqParams {
    distance_type: String,
    num_partitions: Option<u32>,
    max_iterations: u32,
    sample_rate: u32,
    m: u32,
    ef_construction: u32,
}

#[pyclass(get_all)]
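`extract_index_params` replaces the old `Index` pyclass: instead of calling Rust static constructors, the Python side now passes plain config objects (`BTree`, `Bitmap`, `LabelList`, `FTS`, `IvfFlat`, `IvfPq`, `HnswPq`, `HnswSq`) whose class name selects the match arm and whose attributes are read through the `FromPyObject` param structs above. Mirroring the tests earlier in this diff, the Python side looks like:

from lancedb.index import IvfPq

# config's class name ("IvfPq") picks the branch; its attributes fill IvfPqParams
await tbl.create_index(
    "vector",
    config=IvfPq(distance_type="cosine", num_partitions=2, num_sub_vectors=2),
)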
@@ -15,7 +15,7 @@
use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::{Index, IndexConfig};
use index::IndexConfig;
use pyo3::{
    pymodule,
    types::{PyModule, PyModuleMethods},
@@ -40,7 +40,6 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
    env_logger::init_from_env(env);
    m.add_class::<Connection>()?;
    m.add_class::<Table>()?;
    m.add_class::<Index>()?;
    m.add_class::<IndexConfig>()?;
    m.add_class::<Query>()?;
    m.add_class::<VectorQuery>()?;
@@ -152,6 +152,10 @@ impl FTSQuery {
        self.inner = self.inner.clone().select(Select::dynamic(&columns));
    }

    pub fn select_columns(&mut self, columns: Vec<String>) {
        self.inner = self.inner.clone().select(Select::columns(&columns));
    }

    pub fn limit(&mut self, limit: u32) {
        self.inner = self.inner.clone().limit(limit as usize);
    }
@@ -280,6 +284,11 @@ impl VectorQuery {
        self.inner = self.inner.clone().nprobes(nprobe as usize);
    }

    #[pyo3(signature = (lower_bound=None, upper_bound=None))]
    pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
        self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
    }

    pub fn ef(&mut self, ef: u32) {
        self.inner = self.inner.clone().ef(ef as usize);
    }
@@ -341,6 +350,11 @@ impl HybridQuery {
        self.inner_fts.select(columns);
    }

    pub fn select_columns(&mut self, columns: Vec<String>) {
        self.inner_vec.select_columns(columns.clone());
        self.inner_fts.select_columns(columns);
    }

    pub fn limit(&mut self, limit: u32) {
        self.inner_vec.limit(limit);
        self.inner_fts.limit(limit);
|
||||
|
||||
use crate::{
|
||||
error::PythonErrorExt,
|
||||
index::{Index, IndexConfig},
|
||||
index::{extract_index_params, IndexConfig},
|
||||
query::Query,
|
||||
};
|
||||
|
||||
@@ -97,10 +97,12 @@ impl Table {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
/// Returns True if the table is open, False if it is closed.
|
||||
pub fn is_open(&self) -> bool {
|
||||
self.inner.is_some()
|
||||
}
|
||||
|
||||
/// Closes the table, releasing any resources associated with it.
|
||||
pub fn close(&mut self) {
|
||||
self.inner.take();
|
||||
}
|
||||
@@ -177,14 +179,10 @@ impl Table {
|
||||
pub fn create_index<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
column: String,
|
||||
index: Option<&Index>,
|
||||
index: Option<Bound<'_, PyAny>>,
|
||||
replace: Option<bool>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let index = if let Some(index) = index {
|
||||
index.consume()?
|
||||
} else {
|
||||
lancedb::index::Index::Auto
|
||||
};
|
||||
let index = extract_index_params(&index)?;
|
||||
let mut op = self_.inner_ref()?.create_index(&[column], index);
|
||||
if let Some(replace) = replace {
|
||||
op = op.replace(replace);
|
||||
@@ -305,6 +303,7 @@ impl Table {
|
||||
Query::new(self.inner_ref().unwrap().query())
|
||||
}
|
||||
|
||||
/// Optimize the on-disk data by compacting and pruning old data, for better performance.
|
||||
#[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None))]
|
||||
pub fn optimize(
|
||||
self_: PyRef<'_, Self>,
|
||||
|
||||
@@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
        "l2" => Ok(DistanceType::L2),
        "cosine" => Ok(DistanceType::Cosine),
        "dot" => Ok(DistanceType::Dot),
        "hamming" => Ok(DistanceType::Hamming),
        _ => Err(PyValueError::new_err(format!(
            "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
            "Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
            distance_type.as_ref()
        ))),
    }
@@ -1,2 +1,2 @@
[toolchain]
channel = "1.80.0"
channel = "1.83.0"
@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.14.1-beta.3"
version = "0.14.1"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true
@@ -1,13 +1,13 @@
[package]
name = "lancedb"
version = "0.14.1-beta.3"
version = "0.14.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
rust-version = "1.75"
rust-version.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
@@ -1050,6 +1050,8 @@ impl ConnectionInternal for Database {
        write_params.enable_v2_manifest_paths =
            options.enable_v2_manifest_paths.unwrap_or_default();

        let data_schema = data.schema();

        match NativeTable::create(
            &table_uri,
            &options.name,
@@ -1069,7 +1071,18 @@ impl ConnectionInternal for Database {
            CreateTableMode::ExistOk(callback) => {
                let builder = OpenTableBuilder::new(options.parent, options.name);
                let builder = (callback)(builder);
                builder.execute().await
                let table = builder.execute().await?;

                let table_schema = table.schema().await?;

                if table_schema != data_schema {
                    return Err(Error::Schema {
                        message: "Provided schema does not match existing table schema"
                            .to_string(),
                    });
                }

                Ok(table)
            }
            CreateTableMode::Overwrite => unreachable!(),
        },
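With this schema check, `exist_ok` opens now fail loudly on schema drift instead of silently returning a table with a different shape. This is exactly the behavior the Python test `test_create_exist_ok` earlier in the diff asserts:

# Reopening with a mismatched schema now raises instead of silently succeeding.
with pytest.raises(ValueError):
    db.create_table("test", schema=bad_schema, exist_ok=True)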
@@ -17,6 +17,7 @@ use std::sync::Arc;
use scalar::FtsIndexBuilder;
use serde::Deserialize;
use serde_with::skip_serializing_none;
use vector::IvfFlatIndexBuilder;

use crate::{table::TableInternal, DistanceType, Error, Result};

@@ -56,6 +57,9 @@ pub enum Index {
    /// Full text search index using bm25.
    FTS(FtsIndexBuilder),

    /// IVF index
    IvfFlat(IvfFlatIndexBuilder),

    /// IVF index with Product Quantization
    IvfPq(IvfPqIndexBuilder),

@@ -106,6 +110,8 @@ impl IndexBuilder {
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub enum IndexType {
    // Vector
    #[serde(alias = "IVF_FLAT")]
    IvfFlat,
    #[serde(alias = "IVF_PQ")]
    IvfPq,
    #[serde(alias = "IVF_HNSW_PQ")]
@@ -127,6 +133,7 @@ pub enum IndexType {
impl std::fmt::Display for IndexType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::IvfFlat => write!(f, "IVF_FLAT"),
            Self::IvfPq => write!(f, "IVF_PQ"),
            Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
            Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
@@ -147,6 +154,7 @@ impl std::str::FromStr for IndexType {
        "BITMAP" => Ok(Self::Bitmap),
        "LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
        "FTS" | "INVERTED" => Ok(Self::FTS),
        "IVF_FLAT" => Ok(Self::IvfFlat),
        "IVF_PQ" => Ok(Self::IvfPq),
        "IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
        "IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
@@ -77,5 +77,5 @@ impl FtsIndexBuilder {
    }
}

use lance_index::scalar::inverted::TokenizerConfig;
pub use lance_index::scalar::inverted::TokenizerConfig;
pub use lance_index::scalar::FullTextSearchQuery;
@@ -162,6 +162,43 @@ macro_rules! impl_hnsw_params_setter {
    };
}

/// Builder for an IVF Flat index.
///
/// This index stores raw vectors. These vectors are grouped into partitions of similar vectors.
/// Each partition keeps track of a centroid which is the average value of all vectors in the group.
///
/// During a query the centroids are compared with the query vector to find the closest partitions.
/// The raw vectors in these partitions are then searched to find the closest vectors.
///
/// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create.
///
/// Note that training an IVF Flat index on a large dataset is a slow operation and currently is also a memory intensive operation.
#[derive(Debug, Clone)]
pub struct IvfFlatIndexBuilder {
    pub(crate) distance_type: DistanceType,

    // IVF
    pub(crate) num_partitions: Option<u32>,
    pub(crate) sample_rate: u32,
    pub(crate) max_iterations: u32,
}

impl Default for IvfFlatIndexBuilder {
    fn default() -> Self {
        Self {
            distance_type: DistanceType::L2,
            num_partitions: None,
            sample_rate: 256,
            max_iterations: 50,
        }
    }
}

impl IvfFlatIndexBuilder {
    impl_distance_type_setter!();
    impl_ivf_params_setter!();
}

/// Builder for an IVF PQ index.
///
/// This index stores a compressed (quantized) copy of every vector. These vectors
@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
pub mod query;
#[cfg(feature = "remote")]
pub mod remote;
pub mod rerankers;
pub mod table;
pub mod utils;
@@ -15,19 +15,31 @@
use std::future::Future;
use std::sync::Arc;

use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use half::f16;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::{
    arrow::RecordBatchExt,
    dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
};
use lance_datafusion::exec::execute_plan;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;

use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::TableInternal;
use crate::DistanceType;

mod hybrid;

pub(crate) const DEFAULT_TOP_K: usize = 10;

/// Which columns should be retrieved from the database
@@ -339,7 +351,7 @@ pub trait QueryBase {
    fn limit(self, limit: usize) -> Self;

    /// Set the offset of the query.

    ///
    /// By default, it fetches starting with the first row.
    /// This method can be used to skip the first `offset` rows.
    fn offset(self, offset: usize) -> Self;
@@ -435,6 +447,16 @@ pub trait QueryBase {

    /// Return the `_rowid` meta column from the Table.
    fn with_row_id(self) -> Self;

    /// Rerank the results using the specified reranker.
    ///
    /// This is currently only supported for hybrid search.
    fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;

    /// The method to normalize the scores. Can be "rank" or "score". If "rank",
    /// the scores are converted to ranks and then normalized. If "score", the
    /// scores are normalized directly.
    fn norm(self, norm: NormalizeMethod) -> Self;
}

pub trait HasQuery {
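`rerank` and `norm` only take effect for hybrid search, where FTS and vector results must be merged into one ranking. A hypothetical sketch of how this is typically driven through the Python bindings (RRF is the default reranker when none is given, per the Rust code below):

from lancedb.rerankers import RRFReranker

results = (
    table.search("puppy", query_type="hybrid")  # runs FTS + vector search
    .rerank(reranker=RRFReranker())             # merge the two result sets
    .limit(10)
    .to_pandas()
)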
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
        self.mut_query().with_row_id = true;
        self
    }

    fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
        self.mut_query().reranker = Some(reranker);
        self
    }

    fn norm(mut self, norm: NormalizeMethod) -> Self {
        self.mut_query().norm = Some(norm);
        self
    }
}

/// Options for controlling the execution of a query
@@ -600,6 +632,13 @@ pub struct Query {

    /// If set to false, the filter will be applied after the vector search.
    pub(crate) prefilter: bool,

    /// Implementation of reranker that can be used to reorder or combine query
    /// results, especially if using hybrid search
    pub(crate) reranker: Option<Arc<dyn Reranker>>,

    /// Configure how query results are normalized when doing hybrid search
    pub(crate) norm: Option<NormalizeMethod>,
}

impl Query {
@@ -614,6 +653,8 @@ impl Query {
            fast_search: false,
            with_row_id: false,
            prefilter: true,
            reranker: None,
            norm: None,
        }
    }
@@ -714,6 +755,10 @@ pub struct VectorQuery {
|
||||
// IVF PQ - ANN search.
|
||||
pub(crate) query_vector: Vec<Arc<dyn Array>>,
|
||||
pub(crate) nprobes: usize,
|
||||
// The lower bound (inclusive) of the distance to search for.
|
||||
pub(crate) lower_bound: Option<f32>,
|
||||
// The upper bound (exclusive) of the distance to search for.
|
||||
pub(crate) upper_bound: Option<f32>,
|
||||
// The number of candidates to return during the refine step for HNSW,
|
||||
// defaults to 1.5 * limit.
|
||||
pub(crate) ef: Option<usize>,
|
||||
@@ -730,6 +775,8 @@ impl VectorQuery {
|
||||
column: None,
|
||||
query_vector: Vec::new(),
|
||||
nprobes: 20,
|
||||
lower_bound: None,
|
||||
upper_bound: None,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
@@ -790,6 +837,14 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the distance range for vector search,
|
||||
/// only rows with distances in the range [lower_bound, upper_bound) will be returned
|
||||
pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
|
||||
self.lower_bound = lower_bound;
|
||||
self.upper_bound = upper_bound;
|
||||
self
|
||||
}
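// Example (sketch, not part of the diff): keep only rows whose distance falls
// in [0.0, 1.0), mirroring the `test_distance_range` test below; `table` is
// assumed open.
let close_matches = table
    .vector_search(&[0.1, 0.2, 0.3, 0.4])?
    .distance_range(Some(0.0), Some(1.0))
    .execute()
    .await?;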

    /// Set the number of candidates to return during the refine step for HNSW
    ///
    /// This argument is only used when the vector column has an HNSW index.
@@ -862,6 +917,65 @@ impl VectorQuery {
        self.use_index = false;
        self
    }

    pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
        // clone the query and request row IDs, which may be needed for reranking
        let fts_query = self.base.clone().with_row_id();
        let mut vector_query = self.clone().with_row_id();

        vector_query.base.full_text_search = None;
        let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;

        let (fts_results, vec_results) = try_join!(
            fts_results.try_collect::<Vec<_>>(),
            vec_results.try_collect::<Vec<_>>()
        )?;

        // get the schemas to use when combining batches;
        // if either result set is empty, its schema is derived from the other
        let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);

        // concatenate all the batches together
        let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
        let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;

        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
            vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
            fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
        }

        vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
        fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;

        let reranker = self
            .base
            .reranker
            .clone()
            .unwrap_or(Arc::new(RRFReranker::default()));

        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
            message: "there should be an FTS search".to_string(),
        })?;

        let mut results = reranker
            .rerank_hybrid(&fts_query.query, vec_results, fts_results)
            .await?;

        check_reranker_result(&results)?;

        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
        if results.num_rows() > limit {
            results = results.slice(0, limit);
        }

        if !self.base.with_row_id {
            results = results.drop_column(ROW_ID)?;
        }

        Ok(SendableRecordBatchStream::from(
            RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
        ))
    }
}

impl ExecutableQuery for VectorQuery {
@@ -873,6 +987,11 @@ impl ExecutableQuery for VectorQuery {
        &self,
        options: QueryExecutionOptions,
    ) -> Result<SendableRecordBatchStream> {
        if self.base.full_text_search.is_some() {
            let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
            return Ok(hybrid_result);
        }

        Ok(SendableRecordBatchStream::from(
            DatasetRecordBatchStream::new(execute_plan(
                self.create_plan(options).await?,
@@ -894,20 +1013,20 @@ impl HasQuery for VectorQuery {

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::{collections::HashSet, sync::Arc};

    use super::*;
    use arrow::{compute::concat_batches, datatypes::Int32Type};
    use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
    use arrow_array::{
        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
        RecordBatchReader,
        cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
        RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
    };
    use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
    use futures::{StreamExt, TryStreamExt};
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
    use tempfile::tempdir;

    use crate::{connect, Table};
    use crate::{connect, connection::CreateTableMode, Table};

    #[tokio::test]
    async fn test_setters_getters() {
@@ -1245,6 +1364,30 @@ mod tests {
        }
    }

    #[tokio::test]
    async fn test_distance_range() {
        let tmp_dir = tempdir().unwrap();
        let table = make_test_table(&tmp_dir).await;
        let results = table
            .vector_search(&[0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .distance_range(Some(0.0), Some(1.0))
            .limit(10)
            .execute()
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        for batch in results {
            let distances = batch["_distance"].as_primitive::<Float32Type>();
            assert!(distances.iter().all(|d| {
                let d = d.unwrap();
                (0.0..1.0).contains(&d)
            }));
        }
    }

    #[tokio::test]
    async fn test_multiple_query_vectors() {
        let tmp_dir = tempdir().unwrap();
@@ -1274,4 +1417,156 @@ mod tests {
        assert!(query_index.values().contains(&0));
        assert!(query_index.values().contains(&1));
    }

    #[tokio::test]
    async fn test_hybrid_search() {
        let tmp_dir = tempdir().unwrap();
        let dataset_path = tmp_dir.path();
        let conn = connect(dataset_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let dims = 2;
        let schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new("text", DataType::Utf8, false),
            ArrowField::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
                    dims,
                ),
                false,
            ),
        ]));

        let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
        let vectors = vec![
            Some(vec![Some(0.0), Some(0.0)]),
            Some(vec![Some(-2.0), Some(-2.0)]),
            Some(vec![Some(50.0), Some(50.0)]),
            Some(vec![Some(-30.0), Some(-30.0)]),
        ];
        let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);

        let record_batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
        let record_batch_iter =
            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
        let table = conn
            .create_table("my_table", record_batch_iter)
            .execute()
            .await
            .unwrap();

        table
            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
            .replace(true)
            .execute()
            .await
            .unwrap();

        let fts_query = FullTextSearchQuery::new("b".to_string());
        let results = table
            .query()
            .full_text_search(fts_query)
            .limit(2)
            .nearest_to(&[-10.0, -10.0])
            .unwrap()
            .execute()
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let batch = &results[0];

        let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
        let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
        assert!(texts.contains("cat")); // should be close by vector search
        assert!(texts.contains("b")); // should be close by fts search

        // ensure that this works correctly if there are no matching FTS results
        let fts_query = FullTextSearchQuery::new("z".to_string());
        table
            .query()
            .full_text_search(fts_query)
            .limit(2)
            .nearest_to(&[-10.0, -10.0])
            .unwrap()
            .execute()
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_hybrid_search_empty_table() {
        let tmp_dir = tempdir().unwrap();
        let dataset_path = tmp_dir.path();
        let conn = connect(dataset_path.to_str().unwrap())
            .execute()
            .await
            .unwrap();

        let dims = 2;

        let schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new("text", DataType::Utf8, false),
            ArrowField::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
                    dims,
                ),
                false,
            ),
        ]));

        // ensure hybrid search is also supported on a fully empty table
        let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
        let record_batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(Vec::<&str>::new())),
                Arc::new(
                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
                ),
            ],
        )
        .unwrap();
        let record_batch_iter =
            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
        let table = conn
            .create_table("my_table", record_batch_iter)
            .mode(CreateTableMode::Overwrite)
            .execute()
            .await
            .unwrap();
        table
            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
            .replace(true)
            .execute()
            .await
            .unwrap();
        let fts_query = FullTextSearchQuery::new("b".to_string());
        let results = table
            .query()
            .full_text_search(fts_query)
            .limit(2)
            .nearest_to(&[-10.0, -10.0])
            .unwrap()
            .execute()
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let batch = &results[0];
        assert_eq!(0, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }
}

346 rust/lancedb/src/query/hybrid.rs Normal file
@@ -0,0 +1,346 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use arrow::compute::{
    kernels::numeric::{div, sub},
    max, min,
};
use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use lance::dataset::ROW_ID;
use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
use std::sync::Arc;

use crate::error::{Error, Result};

/// Converts the results' score column to a rank.
///
/// Expects the `column` argument to be type Float32 and will panic if it's not
pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
        message: format!(
            "expected column {} not found in rank. found columns {:?}",
            column,
            results
                .schema()
                .fields()
                .iter()
                .map(|f| f.name())
                .collect::<Vec<_>>(),
        ),
    })?;

    if results.num_rows() == 0 {
        return Ok(results);
    }

    let scores: Float32Array = downcast_array(scores);
    let ranks = Float32Array::from_iter_values(
        arrow::compute::kernels::rank::rank(
            &scores,
            Some(SortOptions {
                descending: !ascending.unwrap_or(true),
                ..Default::default()
            }),
        )?
        .iter()
        .map(|i| *i as f32),
    );

    let schema = results.schema();
    let (column_idx, _) = schema.column_with_name(column).unwrap();
    let mut columns = results.columns().to_vec();
    columns[column_idx] = Arc::new(ranks);

    let results = RecordBatch::try_new(results.schema(), columns)?;

    Ok(results)
}

/// Get the query schemas needed when combining the search results.
///
/// If either of the record batches is empty, then we create a schema from the
/// other record batch and replace the score/distance column. If both record
/// batches are empty, create empty schemas.
pub fn query_schemas(
    fts_results: &[RecordBatch],
    vec_results: &[RecordBatch],
) -> (Arc<Schema>, Arc<Schema>) {
    let (fts_schema, vec_schema) = match (
        fts_results.first().map(|r| r.schema()),
        vec_results.first().map(|r| r.schema()),
    ) {
        (Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
        (None, Some(vec_schema)) => {
            let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
            (Arc::new(fts_schema), vec_schema)
        }
        (Some(fts_schema), None) => {
            // derive the vector schema by renaming the score column to the distance column
            let vec_schema = with_field_name_replaced(&fts_schema, SCORE_COL, DIST_COL);
            (fts_schema, Arc::new(vec_schema))
        }
        (None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
    };

    (fts_schema, vec_schema)
}

pub fn empty_fts_schema() -> Schema {
    Schema::new(vec![
        Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
    ])
}

pub fn empty_vec_schema() -> Schema {
    Schema::new(vec![
        Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
    ])
}

pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
    let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
        if field.name() == target {
            Some(i)
        } else {
            None
        }
    });

    let mut fields = schema.fields().to_vec();
    if let Some(idx) = field_idx {
        let new_field = (*fields[idx]).clone().with_name(replacement);
        fields[idx] = Arc::new(new_field);
    }

    Schema::new(fields)
}
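// Example (sketch, not part of the diff): deriving an FTS schema from a vector
// schema by renaming the distance column, exactly as `query_schemas` does above:
let vec_schema = empty_vec_schema(); // fields: [_distance, _rowid]
let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
assert!(fts_schema.column_with_name(SCORE_COL).is_some());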

/// Normalize the scores column to have values between 0 and 1.
///
/// Expects the `column` argument to be type Float32 and will panic if it's not
pub fn normalize_scores(
    results: RecordBatch,
    column: &str,
    invert: Option<bool>,
) -> Result<RecordBatch> {
    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
        message: format!(
            "expected column {} not found in normalize_scores. found columns {:?}",
            column,
            results
                .schema()
                .fields()
                .iter()
                .map(|f| f.name())
                .collect::<Vec<_>>(),
        ),
    })?;

    if results.num_rows() == 0 {
        return Ok(results);
    }
    let mut scores: Float32Array = downcast_array(scores);

    let max = max(&scores).unwrap_or(0.0);
    let min = min(&scores).unwrap_or(0.0);

    // this is equivalent to np.isclose, which is used in Python
    let rng = if max - min < 10e-5 { max } else { max - min };

    // if rng is 0, then min and max are both 0 so we just leave the scores as is
    if rng != 0.0 {
        let tmp = div(
            &sub(&scores, &Float32Array::new_scalar(min))?,
            &Float32Array::new_scalar(rng),
        )?;
        scores = downcast_array(&tmp);
    }

    if invert.unwrap_or(false) {
        let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
        scores = downcast_array(&tmp);
    }

    let schema = results.schema();
    let (column_idx, _) = schema.column_with_name(column).unwrap();
    let mut columns = results.columns().to_vec();
    columns[column_idx] = Arc::new(scores);

    let results = RecordBatch::try_new(results.schema(), columns).unwrap();

    Ok(results)
}

#[cfg(test)]
mod test {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};

    #[test]
    fn test_rank() {
        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new("name", DataType::Utf8, false)),
            Arc::new(Field::new("score", DataType::Float32, false)),
        ]));

        let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]);
        let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]);

        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap();

        let result = rank(batch.clone(), "score", Some(false)).unwrap();
        assert_eq!(2, result.schema().fields().len());
        assert_eq!("name", result.schema().field(0).name());
        assert_eq!("score", result.schema().field(1).name());

        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![4.0, 3.0, 5.0, 1.0, 2.0]
        );

        // check sort ascending
        let result = rank(batch.clone(), "score", Some(true)).unwrap();
        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![2.0, 3.0, 1.0, 5.0, 4.0]
        );

        // ensure default sort is ascending
        let result = rank(batch.clone(), "score", None).unwrap();
        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![2.0, 3.0, 1.0, 5.0, 4.0]
        );

        // check it can handle an empty batch
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(Vec::<&str>::new())),
                Arc::new(Float32Array::from(Vec::<f32>::new())),
            ],
        )
        .unwrap();
        let result = rank(batch.clone(), "score", None).unwrap();
        assert_eq!(0, result.num_rows());
        assert_eq!(2, result.schema().fields().len());
        assert_eq!("name", result.schema().field(0).name());
        assert_eq!("score", result.schema().field(1).name());

        // check it returns the expected error when there's no column
        let result = rank(batch.clone(), "bad_col", None);
        match result {
            Err(Error::InvalidInput { message }) => {
                assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
            }
            _ => {
                panic!("expected invalid input error, received {:?}", result)
            }
        }
    }

    #[test]
    fn test_normalize_scores() {
        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new("name", DataType::Utf8, false)),
            Arc::new(Field::new("score", DataType::Float32, false)),
        ]));

        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
        let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));

        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();

        let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.6, 0.4, 0.7, 1.0]
        );

        // check it can invert the normalization
        let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
        );

        // check that the default is not inverted
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.6, 0.4, 0.7, 1.0]
        );

        // check that it will function correctly if all the values are the same
        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
        let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0, 0.0, 0.0]
        );

        // check that scores differing only by floating point rounding are normalized
        // the same way, i.e., the behaviour is consistent with Python
        let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                0.0
            ]
        );

        // check that it can handle if all the scores are 0
        let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0, 0.0, 0.0]
        );
    }
}
@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
        body["prefilter"] = query.base.prefilter.into();
        body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
        body["nprobes"] = query.nprobes.into();
        body["lower_bound"] = query.lower_bound.into();
        body["upper_bound"] = query.upper_bound.into();
        body["ef"] = query.ef.into();
        body["refine_factor"] = query.refine_factor.into();
        if let Some(vector_column) = query.column.as_ref() {
@@ -563,6 +565,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
        let (index_type, distance_type) = match index.index {
            // TODO: Should we pass the actual index parameters? SaaS does not
            // yet support them.
            Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)),
            Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
            Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
            Index::BTree(_) => ("BTREE", None),
@@ -873,6 +876,7 @@ mod tests {
    use lance_index::scalar::FullTextSearchQuery;
    use reqwest::Body;

    use crate::index::vector::IvfFlatIndexBuilder;
    use crate::{
        index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
        query::{ExecutableQuery, QueryBase},
@@ -1302,6 +1306,8 @@ mod tests {
                "prefilter": true,
                "distance_type": "l2",
                "nprobes": 20,
                "lower_bound": Option::<f32>::None,
                "upper_bound": Option::<f32>::None,
                "k": 10,
                "ef": Option::<usize>::None,
                "refine_factor": null,
@@ -1351,6 +1357,8 @@ mod tests {
                "bypass_vector_index": true,
                "columns": ["a", "b"],
                "nprobes": 12,
                "lower_bound": Option::<f32>::None,
                "upper_bound": Option::<f32>::None,
                "ef": Option::<usize>::None,
                "refine_factor": 2,
                "version": null,
@@ -1489,6 +1497,11 @@ mod tests {
    #[tokio::test]
    async fn test_create_index() {
        let cases = [
            (
                "IVF_FLAT",
                Some("hamming"),
                Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
            ),
            ("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
            (
                "IVF_PQ",

87 rust/lancedb/src/rerankers.rs Normal file
@@ -0,0 +1,87 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::BTreeSet;

use arrow::{
    array::downcast_array,
    compute::{concat_batches, filter_record_batch},
};
use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
use async_trait::async_trait;
use lance::dataset::ROW_ID;

use crate::error::{Error, Result};

pub mod rrf;

/// column name for reranker relevance score
const RELEVANCE_SCORE: &str = "_relevance_score";

#[derive(Debug, Clone, PartialEq)]
pub enum NormalizeMethod {
    Score,
    Rank,
}

/// Interface for a reranker. A reranker is used to rerank the results from a
/// vector and FTS search. This is useful for combining the results from both
/// search methods.
#[async_trait]
pub trait Reranker: std::fmt::Debug + Sync + Send {
    // TODO support vector reranking and FTS reranking. Currently only hybrid reranking is supported.

    /// The rerank function receives the individual result sets from the vector
    /// and FTS searches. You can use either or both of them to generate the
    /// final results, allowing maximum flexibility.
    async fn rerank_hybrid(
        &self,
        query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch>;

    fn merge_results(
        &self,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;

        let mut mask = BooleanArray::builder(combined.num_rows());
        let mut unique_ids = BTreeSet::new();
        let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
            message: format!(
                "could not find expected column {} while merging results. found columns {:?}",
                ROW_ID,
                combined
                    .schema()
                    .fields()
                    .iter()
                    .map(|f| f.name())
                    .collect::<Vec<_>>()
            ),
        })?;
        let row_ids: UInt64Array = downcast_array(row_ids);
        row_ids.values().iter().for_each(|id| {
            mask.append_value(unique_ids.insert(id));
        });

        let combined = filter_record_batch(&combined, &mask.finish())?;

        Ok(combined)
    }
}
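// Example (sketch, not part of the diff): a minimal in-crate `Reranker`
// implementation. It merges the two result sets with the default
// `merge_results` and attaches a constant `_relevance_score` column so that
// `check_reranker_result` passes. The struct name and the constant scoring are
// hypothetical; a real reranker would compute a per-row score.
#[derive(Debug)]
struct ConstantReranker;

#[async_trait]
impl Reranker for ConstantReranker {
    async fn rerank_hybrid(
        &self,
        _query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        let merged = self.merge_results(vector_results, fts_results)?;
        // every row gets the same relevance score in this sketch
        let scores = arrow_array::Float32Array::from(vec![1.0; merged.num_rows()]);
        let mut fields = merged.schema().fields().to_vec();
        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
            RELEVANCE_SCORE,
            arrow_schema::DataType::Float32,
            false,
        )));
        let mut columns = merged.columns().to_vec();
        columns.push(std::sync::Arc::new(scores));
        Ok(RecordBatch::try_new(
            std::sync::Arc::new(arrow_schema::Schema::new(fields)),
            columns,
        )?)
    }
}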

pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
    if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
        return Err(Error::Schema {
            message: format!(
                "rerank_hybrid must return a RecordBatch with a column named {}",
                RELEVANCE_SCORE
            ),
        });
    }

    Ok(())
}
223 rust/lancedb/src/rerankers/rrf.rs Normal file
@@ -0,0 +1,223 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::BTreeMap;
use std::sync::Arc;

use arrow::{
    array::downcast_array,
    compute::{sort_to_indices, take},
};
use arrow_array::{Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use async_trait::async_trait;
use lance::dataset::ROW_ID;

use crate::error::{Error, Result};
use crate::rerankers::{Reranker, RELEVANCE_SCORE};

/// Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm based
/// on the scores of vector and FTS search.
#[derive(Debug)]
pub struct RRFReranker {
    k: f32,
}

impl RRFReranker {
    /// Create a new RRFReranker
    ///
    /// The parameter k is a constant used in the RRF formula (default is 60).
    /// Experiments indicate that k = 60 was near-optimal, but that the choice
    /// is not critical. See paper:
    /// https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
    pub fn new(k: f32) -> Self {
        Self { k }
    }
}

impl Default for RRFReranker {
    fn default() -> Self {
        Self { k: 60.0 }
    }
}
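// Worked example (sketch, not part of the diff): each list contributes
// 1.0 / (i + k) for a row at 0-based position i, and contributions are summed
// across lists. With the default k = 60, a row ranked first in both lists
// scores 2/60:
let k = 60.0_f32;
let score = 1.0 / (0.0 + k) + 1.0 / (0.0 + k);
assert!((score - 2.0 / 60.0).abs() < 1e-6);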

#[async_trait]
impl Reranker for RRFReranker {
    async fn rerank_hybrid(
        &self,
        _query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        let vector_ids = vector_results
            .column_by_name(ROW_ID)
            .ok_or(Error::InvalidInput {
                message: format!(
                    "expected column {} not found in vector_results. found columns {:?}",
                    ROW_ID,
                    vector_results
                        .schema()
                        .fields()
                        .iter()
                        .map(|f| f.name())
                        .collect::<Vec<_>>()
                ),
            })?;
        let fts_ids = fts_results
            .column_by_name(ROW_ID)
            .ok_or(Error::InvalidInput {
                message: format!(
                    "expected column {} not found in fts_results. found columns {:?}",
                    ROW_ID,
                    fts_results
                        .schema()
                        .fields()
                        .iter()
                        .map(|f| f.name())
                        .collect::<Vec<_>>()
                ),
            })?;

        let vector_ids: UInt64Array = downcast_array(&vector_ids);
        let fts_ids: UInt64Array = downcast_array(&fts_ids);

        let mut rrf_score_map = BTreeMap::new();
        let mut update_score_map = |(i, result_id)| {
            let score = 1.0 / (i as f32 + self.k);
            rrf_score_map
                .entry(result_id)
                .and_modify(|e| *e += score)
                .or_insert(score);
        };
        vector_ids
            .values()
            .iter()
            .enumerate()
            .for_each(&mut update_score_map);
        fts_ids
            .values()
            .iter()
            .enumerate()
            .for_each(&mut update_score_map);

        let combined_results = self.merge_results(vector_results, fts_results)?;

        let combined_row_ids: UInt64Array =
            downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
        let relevance_scores = Float32Array::from_iter_values(
            combined_row_ids
                .values()
                .iter()
                .map(|row_id| rrf_score_map.get(row_id).unwrap())
                .copied(),
        );

        // keep track of indices sorted by the relevance column
        let sort_indices = sort_to_indices(
            &relevance_scores,
            Some(SortOptions {
                descending: true,
                ..Default::default()
            }),
            None,
        )
        .unwrap();

        // add relevance scores to columns
        let mut columns = combined_results.columns().to_vec();
        columns.push(Arc::new(relevance_scores));

        // sort by the relevance scores
        let columns = columns
            .iter()
            .map(|c| take(c, &sort_indices, None).unwrap())
            .collect();

        // add relevance score to schema
        let mut fields = combined_results.schema().fields().to_vec();
        fields.push(Arc::new(Field::new(
            RELEVANCE_SCORE,
            DataType::Float32,
            false,
        )));
        let schema = Schema::new(fields);

        let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;

        Ok(combined_results)
    }
}

#[cfg(test)]
pub mod test {
    use super::*;
    use arrow_array::StringArray;

    #[tokio::test]
    async fn test_rrf_reranker() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new(ROW_ID, DataType::UInt64, false),
        ]));

        let vec_results = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"])),
                Arc::new(UInt64Array::from(vec![1, 4, 2, 5, 3])),
            ],
        )
        .unwrap();

        let fts_results = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["bar", "bean", "dog"])),
                Arc::new(UInt64Array::from(vec![4, 5, 3])),
            ],
        )
        .unwrap();

        // with k = 1, the scores should be calculated as:
        // - foo = 1/1 = 1.0
        // - bar = 1/2 + 1/1 = 1.5
        // - baz = 1/3 = 0.333
        // - bean = 1/4 + 1/2 = 0.75
        // - dog = 1/5 + 1/3 = 0.533
        // then we should get the result ranked in descending order

        let reranker = RRFReranker::new(1.0);

        let result = reranker
            .rerank_hybrid("", vec_results, fts_results)
            .await
            .unwrap();

        assert_eq!(3, result.schema().fields().len());
        assert_eq!("name", result.schema().fields().first().unwrap().name());
        assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name());
        assert_eq!(
            RELEVANCE_SCORE,
            result.schema().fields().get(2).unwrap().name()
        );

        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["bar", "foo", "bean", "dog", "baz"]
        );

        let ids: UInt64Array = downcast_array(result.column(1));
        assert_eq!(
            ids.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![4, 1, 5, 3, 2]
        );

        let scores: Float32Array = downcast_array(result.column(2));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0]
        );
    }
}
@@ -18,9 +18,9 @@ use std::path::Path;
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::datatypes::Float32Type;
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{Field, Schema, SchemaRef};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use datafusion_physical_plan::projection::ProjectionExec;
@@ -58,8 +58,8 @@ use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, M
use crate::error::{Error, Result};
use crate::index::scalar::FtsIndexBuilder;
use crate::index::vector::{
    suggested_num_partitions_for_hnsw, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder,
    IvfPqIndexBuilder, VectorIndex,
    suggested_num_partitions_for_hnsw, IvfFlatIndexBuilder, IvfHnswPqIndexBuilder,
    IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
};
use crate::index::IndexStatistics;
use crate::index::{
@@ -1306,6 +1306,44 @@ impl NativeTable {
            .collect())
    }

    async fn create_ivf_flat_index(
        &self,
        index: IvfFlatIndexBuilder,
        field: &Field,
        replace: bool,
    ) -> Result<()> {
        if !supported_vector_data_type(field.data_type()) {
            return Err(Error::InvalidInput {
                message: format!(
                    "An IVF Flat index cannot be created on the column `{}` which has data type {}",
                    field.name(),
                    field.data_type()
                ),
            });
        }

        let num_partitions = if let Some(n) = index.num_partitions {
            n
        } else {
            suggested_num_partitions(self.count_rows(None).await?)
        };
        let mut dataset = self.dataset.get_mut().await?;
        let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_flat(
            num_partitions as usize,
            index.distance_type.into(),
        );
        dataset
            .create_index(
                &[field.name()],
                IndexType::Vector,
                None,
                &lance_idx_params,
                replace,
            )
            .await?;
        Ok(())
    }
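// Example (sketch, not part of the diff): creating an IVF_FLAT index through
// the public API, assuming an open `table` with a "vector" column; the builder
// is the one shown earlier.
table
    .create_index(
        &["vector"],
        Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::L2)),
    )
    .execute()
    .await?;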

    async fn create_ivf_pq_index(
        &self,
        index: IvfPqIndexBuilder,
@@ -1778,6 +1816,10 @@ impl TableInternal for NativeTable {
            Index::Bitmap(_) => self.create_bitmap_index(field, opts).await,
            Index::LabelList(_) => self.create_label_list_index(field, opts).await,
            Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await,
            Index::IvfFlat(ivf_flat) => {
                self.create_ivf_flat_index(ivf_flat, field, opts.replace)
                    .await
            }
            Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
            Index::IvfHnswPq(ivf_hnsw_pq) => {
                self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
@@ -1848,14 +1890,21 @@ impl TableInternal for NativeTable {
                message: format!("Column {} not found in dataset schema", column),
            })?;

        if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
            if !f.data_type().is_floating() {
                return Err(Error::InvalidInput {
                    message: format!(
                        "The data type of the vector column '{}' is not a floating point type",
                        column
                    ),
                });
        let mut is_binary = false;
        if let arrow_schema::DataType::FixedSizeList(element, dim) = field.data_type() {
            match element.data_type() {
                e_type if e_type.is_floating() => {}
                e_type if *e_type == DataType::UInt8 => {
                    is_binary = true;
                }
                _ => {
                    return Err(Error::InvalidInput {
                        message: format!(
                            "The data type of the vector column '{}' is not a floating point type",
                            column
                        ),
                    });
                }
            }
            if dim != query_vector.len() as i32 {
                return Err(Error::InvalidInput {
@@ -1870,12 +1919,22 @@ impl TableInternal for NativeTable {
            }
        }

        let query_vector = query_vector.as_primitive::<Float32Type>();
        scanner.nearest(
            &column,
            query_vector,
            query.base.limit.unwrap_or(DEFAULT_TOP_K),
        )?;
        if is_binary {
            let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
            let query_vector = query_vector.as_primitive::<UInt8Type>();
            scanner.nearest(
                &column,
                query_vector,
                query.base.limit.unwrap_or(DEFAULT_TOP_K),
            )?;
        } else {
            let query_vector = query_vector.as_primitive::<Float32Type>();
            scanner.nearest(
                &column,
                query_vector,
                query.base.limit.unwrap_or(DEFAULT_TOP_K),
            )?;
        }
        scanner.limit(
            query.base.limit.map(|limit| limit as i64),
@@ -1885,6 +1944,7 @@ impl TableInternal for NativeTable {
        if let Some(ef) = query.ef {
            scanner.ef(ef);
        }
        scanner.distance_range(query.lower_bound, query.upper_bound);
        scanner.use_index(query.use_index);
        scanner.prefilter(query.base.prefilter);
        match query.base.select {

@@ -110,7 +110,7 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
        .iter()
        .filter_map(|field| match field.data_type() {
            arrow_schema::DataType::FixedSizeList(f, d)
                if f.data_type().is_floating()
                if (f.data_type().is_floating() || f.data_type() == &DataType::UInt8)
                    && dim.map(|expect| *d == expect).unwrap_or(true) =>
            {
                Some(field.name())
@@ -171,7 +171,9 @@ pub fn supported_fts_data_type(dtype: &DataType) -> bool {

pub fn supported_vector_data_type(dtype: &DataType) -> bool {
    match dtype {
        DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
        DataType::FixedSizeList(inner, _) => {
            DataType::is_floating(inner.data_type()) || *inner.data_type() == DataType::UInt8
        }
        _ => false,
    }
}
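// Example (sketch, not part of the diff): with the UInt8 support above, a
// binary vector column (FixedSizeList<UInt8>) can be searched with Hamming
// distance. The float query vector is cast to UInt8 internally, as in the
// scanner code above; the `distance_type` setter on the query is assumed here.
let results = table
    .vector_search(&[1.0, 0.0, 1.0, 1.0])?
    .distance_type(DistanceType::Hamming)
    .limit(5)
    .execute()
    .await?;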