mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
Compare commits
48 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92d845fa72 | ||
|
|
397813f6a4 | ||
|
|
50c30c5d34 | ||
|
|
c9f248b058 | ||
|
|
0cb6da6b7e | ||
|
|
aec8332eb5 | ||
|
|
46061070e6 | ||
|
|
dae8334d0b | ||
|
|
8c81968b59 | ||
|
|
16cf2990f3 | ||
|
|
0a0f667bbd | ||
|
|
03753fd84b | ||
|
|
55cceaa309 | ||
|
|
c3797eb834 | ||
|
|
c0d0f38494 | ||
|
|
6a8ab78d0a | ||
|
|
27404c8623 | ||
|
|
f181c7e77f | ||
|
|
e70fd4fecc | ||
|
|
ac0068b80e | ||
|
|
ebac960571 | ||
|
|
59b57055e7 | ||
|
|
591c8de8fc | ||
|
|
f835ff310f | ||
|
|
cf8c2edaf4 | ||
|
|
61a714a459 | ||
|
|
5ddd84cec0 | ||
|
|
27ef0bb0a2 | ||
|
|
25402ba6ec | ||
|
|
37c359ed40 | ||
|
|
06cdf00987 | ||
|
|
144b7f5d54 | ||
|
|
edc9b9adec | ||
|
|
d11b2a6975 | ||
|
|
980aa70e2d | ||
|
|
d83e5a0208 | ||
|
|
16a6b9ce8f | ||
|
|
e3c6213333 | ||
|
|
00552439d9 | ||
|
|
c0ee370f83 | ||
|
|
17e4022045 | ||
|
|
c3ebac1a92 | ||
|
|
10f919a0a9 | ||
|
|
8af5476395 | ||
|
|
bcbbeb7a00 | ||
|
|
d6c0f75078 | ||
|
|
e820e356a0 | ||
|
|
509286492f |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.14.1-beta.1"
|
||||
current_version = "0.14.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
4
.github/workflows/make-release-commit.yml
vendored
4
.github/workflows/make-release-commit.yml
vendored
@@ -97,3 +97,7 @@ jobs:
|
||||
if: ${{ !inputs.dry_run && inputs.other }}
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- uses: ./.github/workflows/update_package_lock_nodejs
|
||||
if: ${{ !inputs.dry_run && inputs.other }}
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
200
.github/workflows/npm-publish.yml
vendored
200
.github/workflows/npm-publish.yml
vendored
@@ -159,7 +159,7 @@ jobs:
|
||||
- name: Install common dependencies
|
||||
run: |
|
||||
apk add protobuf-dev curl clang mold grep npm bash
|
||||
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
|
||||
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
|
||||
echo "source $HOME/.cargo/env" >> saved_env
|
||||
echo "export CC=clang" >> saved_env
|
||||
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
|
||||
@@ -167,7 +167,7 @@ jobs:
|
||||
if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
run: |
|
||||
source "$HOME/.cargo/env"
|
||||
rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
|
||||
rustup target add aarch64-unknown-linux-musl
|
||||
crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
|
||||
sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
|
||||
apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
|
||||
@@ -262,7 +262,7 @@ jobs:
|
||||
- name: Install common dependencies
|
||||
run: |
|
||||
apk add protobuf-dev curl clang mold grep npm bash openssl-dev openssl-libs-static
|
||||
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
|
||||
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
|
||||
echo "source $HOME/.cargo/env" >> saved_env
|
||||
echo "export CC=clang" >> saved_env
|
||||
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
|
||||
@@ -272,7 +272,7 @@ jobs:
|
||||
if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
run: |
|
||||
source "$HOME/.cargo/env"
|
||||
rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
|
||||
rustup target add aarch64-unknown-linux-musl
|
||||
crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
|
||||
sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
|
||||
apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
|
||||
@@ -334,50 +334,51 @@ jobs:
|
||||
path: |
|
||||
node/dist/lancedb-vectordb-win32*.tgz
|
||||
|
||||
node-windows-arm64:
|
||||
name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
runs-on: ubuntu-latest
|
||||
container: alpine:edge
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
# - arch: x86_64
|
||||
- arch: aarch64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
|
||||
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
|
||||
echo "source $HOME/.cargo/env" >> saved_env
|
||||
echo "export CC=clang" >> saved_env
|
||||
echo "export AR=llvm-ar" >> saved_env
|
||||
source "$HOME/.cargo/env"
|
||||
rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
|
||||
(mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
|
||||
echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
|
||||
echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
|
||||
- name: Configure x86_64 build
|
||||
if: ${{ matrix.config.arch == 'x86_64' }}
|
||||
run: |
|
||||
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
|
||||
- name: Configure aarch64 build
|
||||
if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
run: |
|
||||
echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
|
||||
- name: Build Windows Artifacts
|
||||
run: |
|
||||
source ./saved_env
|
||||
bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
- name: Upload Windows Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: node-native-windows-${{ matrix.config.arch }}
|
||||
path: |
|
||||
node/dist/lancedb-vectordb-win32*.tgz
|
||||
# TODO: https://github.com/lancedb/lancedb/issues/1975
|
||||
# node-windows-arm64:
|
||||
# name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
# # if: startsWith(github.ref, 'refs/tags/v')
|
||||
# runs-on: ubuntu-latest
|
||||
# container: alpine:edge
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# config:
|
||||
# # - arch: x86_64
|
||||
# - arch: aarch64
|
||||
# steps:
|
||||
# - name: Checkout
|
||||
# uses: actions/checkout@v4
|
||||
# - name: Install dependencies
|
||||
# run: |
|
||||
# apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
|
||||
# curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
|
||||
# echo "source $HOME/.cargo/env" >> saved_env
|
||||
# echo "export CC=clang" >> saved_env
|
||||
# echo "export AR=llvm-ar" >> saved_env
|
||||
# source "$HOME/.cargo/env"
|
||||
# rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
# (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
|
||||
# echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
|
||||
# echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
|
||||
# - name: Configure x86_64 build
|
||||
# if: ${{ matrix.config.arch == 'x86_64' }}
|
||||
# run: |
|
||||
# echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
|
||||
# - name: Configure aarch64 build
|
||||
# if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
# run: |
|
||||
# echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
|
||||
# - name: Build Windows Artifacts
|
||||
# run: |
|
||||
# source ./saved_env
|
||||
# bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
# - name: Upload Windows Artifacts
|
||||
# uses: actions/upload-artifact@v4
|
||||
# with:
|
||||
# name: node-native-windows-${{ matrix.config.arch }}
|
||||
# path: |
|
||||
# node/dist/lancedb-vectordb-win32*.tgz
|
||||
|
||||
nodejs-windows:
|
||||
name: lancedb ${{ matrix.target }}
|
||||
@@ -413,57 +414,58 @@ jobs:
|
||||
path: |
|
||||
nodejs/dist/*.node
|
||||
|
||||
nodejs-windows-arm64:
|
||||
name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
# Only runs on tags that matches the make-release action
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
runs-on: ubuntu-latest
|
||||
container: alpine:edge
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
# - arch: x86_64
|
||||
- arch: aarch64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
|
||||
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
|
||||
echo "source $HOME/.cargo/env" >> saved_env
|
||||
echo "export CC=clang" >> saved_env
|
||||
echo "export AR=llvm-ar" >> saved_env
|
||||
source "$HOME/.cargo/env"
|
||||
rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
|
||||
(mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
|
||||
echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
|
||||
echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
|
||||
printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
|
||||
chmod u+x $HOME/.cargo/bin/cargo-xwin
|
||||
- name: Configure x86_64 build
|
||||
if: ${{ matrix.config.arch == 'x86_64' }}
|
||||
run: |
|
||||
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
|
||||
- name: Configure aarch64 build
|
||||
if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
run: |
|
||||
echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
|
||||
- name: Build Windows Artifacts
|
||||
run: |
|
||||
source ./saved_env
|
||||
bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
|
||||
- name: Upload Windows Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: nodejs-native-windows-${{ matrix.config.arch }}
|
||||
path: |
|
||||
nodejs/dist/*.node
|
||||
# TODO: https://github.com/lancedb/lancedb/issues/1975
|
||||
# nodejs-windows-arm64:
|
||||
# name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
# # Only runs on tags that matches the make-release action
|
||||
# # if: startsWith(github.ref, 'refs/tags/v')
|
||||
# runs-on: ubuntu-latest
|
||||
# container: alpine:edge
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# config:
|
||||
# # - arch: x86_64
|
||||
# - arch: aarch64
|
||||
# steps:
|
||||
# - name: Checkout
|
||||
# uses: actions/checkout@v4
|
||||
# - name: Install dependencies
|
||||
# run: |
|
||||
# apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
|
||||
# curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
|
||||
# echo "source $HOME/.cargo/env" >> saved_env
|
||||
# echo "export CC=clang" >> saved_env
|
||||
# echo "export AR=llvm-ar" >> saved_env
|
||||
# source "$HOME/.cargo/env"
|
||||
# rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
|
||||
# (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
|
||||
# echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
|
||||
# echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
|
||||
# printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
|
||||
# chmod u+x $HOME/.cargo/bin/cargo-xwin
|
||||
# - name: Configure x86_64 build
|
||||
# if: ${{ matrix.config.arch == 'x86_64' }}
|
||||
# run: |
|
||||
# echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
|
||||
# - name: Configure aarch64 build
|
||||
# if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
# run: |
|
||||
# echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
|
||||
# - name: Build Windows Artifacts
|
||||
# run: |
|
||||
# source ./saved_env
|
||||
# bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
|
||||
# - name: Upload Windows Artifacts
|
||||
# uses: actions/upload-artifact@v4
|
||||
# with:
|
||||
# name: nodejs-native-windows-${{ matrix.config.arch }}
|
||||
# path: |
|
||||
# nodejs/dist/*.node
|
||||
|
||||
release:
|
||||
name: vectordb NPM Publish
|
||||
needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows, node-windows-arm64]
|
||||
needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows]
|
||||
runs-on: ubuntu-latest
|
||||
# Only runs on tags that matches the make-release action
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
@@ -503,7 +505,7 @@ jobs:
|
||||
|
||||
release-nodejs:
|
||||
name: lancedb NPM Publish
|
||||
needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows, nodejs-windows-arm64]
|
||||
needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows]
|
||||
runs-on: ubuntu-latest
|
||||
# Only runs on tags that matches the make-release action
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
@@ -571,7 +573,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: main
|
||||
persist-credentials: false
|
||||
token: ${{ secrets.LANCEDB_RELEASE_TOKEN }}
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- uses: ./.github/workflows/update_package_lock
|
||||
@@ -589,7 +591,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: main
|
||||
persist-credentials: false
|
||||
token: ${{ secrets.LANCEDB_RELEASE_TOKEN }}
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- uses: ./.github/workflows/update_package_lock_nodejs
|
||||
|
||||
40
.github/workflows/rust.yml
vendored
40
.github/workflows/rust.yml
vendored
@@ -185,7 +185,7 @@ jobs:
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
|
||||
|
||||
# Add MSVC runtime libraries to LIB
|
||||
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
|
||||
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
|
||||
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
|
||||
@@ -238,3 +238,41 @@ jobs:
|
||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||
cargo build --target aarch64-pc-windows-msvc
|
||||
cargo test --target aarch64-pc-windows-msvc
|
||||
|
||||
msrv:
|
||||
# Check the minimum supported Rust version
|
||||
name: MSRV Check - Rust v${{ matrix.msrv }}
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
msrv: ["1.78.0"] # This should match up with rust-version in Cargo.toml
|
||||
env:
|
||||
# Need up-to-date compilers for kernels
|
||||
CC: clang-18
|
||||
CXX: clang++-18
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
- name: Install ${{ matrix.msrv }}
|
||||
uses: dtolnay/rust-toolchain@master
|
||||
with:
|
||||
toolchain: ${{ matrix.msrv }}
|
||||
- name: Downgrade dependencies
|
||||
# These packages have newer requirements for MSRV
|
||||
run: |
|
||||
cargo update -p aws-sdk-bedrockruntime --precise 1.64.0
|
||||
cargo update -p aws-sdk-dynamodb --precise 1.55.0
|
||||
cargo update -p aws-config --precise 1.5.10
|
||||
cargo update -p aws-sdk-kms --precise 1.51.0
|
||||
cargo update -p aws-sdk-s3 --precise 1.65.0
|
||||
cargo update -p aws-sdk-sso --precise 1.50.0
|
||||
cargo update -p aws-sdk-ssooidc --precise 1.51.0
|
||||
cargo update -p aws-sdk-sts --precise 1.51.0
|
||||
cargo update -p home --precise 0.5.9
|
||||
- name: cargo +${{ matrix.msrv }} check
|
||||
run: cargo check --workspace --tests --benches --all-features
|
||||
|
||||
4
.github/workflows/upload_wheel/action.yml
vendored
4
.github/workflows/upload_wheel/action.yml
vendored
@@ -22,7 +22,7 @@ runs:
|
||||
shell: bash
|
||||
id: choose_repo
|
||||
run: |
|
||||
if [ ${{ github.ref }} == "*beta*" ]; then
|
||||
if [[ ${{ github.ref }} == *beta* ]]; then
|
||||
echo "repo=fury" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "repo=pypi" >> $GITHUB_OUTPUT
|
||||
@@ -33,7 +33,7 @@ runs:
|
||||
FURY_TOKEN: ${{ inputs.fury_token }}
|
||||
PYPI_TOKEN: ${{ inputs.pypi_token }}
|
||||
run: |
|
||||
if [ ${{ steps.choose_repo.outputs.repo }} == "fury" ]; then
|
||||
if [[ ${{ steps.choose_repo.outputs.repo }} == fury ]]; then
|
||||
WHEEL=$(ls target/wheels/lancedb-*.whl 2> /dev/null | head -n 1)
|
||||
echo "Uploading $WHEEL to Fury"
|
||||
curl -f -F package=@$WHEEL https://$FURY_TOKEN@push.fury.io/lancedb/
|
||||
|
||||
20
Cargo.toml
20
Cargo.toml
@@ -18,19 +18,19 @@ repository = "https://github.com/lancedb/lancedb"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
|
||||
rust-version = "1.78.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.20.0", "features" = [
|
||||
lance = { "version" = "=0.21.1", "features" = [
|
||||
"dynamodb",
|
||||
] }
|
||||
lance-io = "0.20.0"
|
||||
lance-index = "0.20.0"
|
||||
lance-linalg = "0.20.0"
|
||||
lance-table = "0.20.0"
|
||||
lance-testing = "0.20.0"
|
||||
lance-datafusion = "0.20.0"
|
||||
lance-encoding = "0.20.0"
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "53.2", optional = false }
|
||||
arrow-array = "53.2"
|
||||
|
||||
@@ -62,6 +62,7 @@ plugins:
|
||||
# for cross references
|
||||
- https://arrow.apache.org/docs/objects.inv
|
||||
- https://pandas.pydata.org/docs/objects.inv
|
||||
- https://lancedb.github.io/lance/objects.inv
|
||||
- mkdocs-jupyter
|
||||
- render_swagger:
|
||||
allow_arbitrary_locations: true
|
||||
@@ -231,6 +232,7 @@ nav:
|
||||
- 🐍 Python: python/saas-python.md
|
||||
- 👾 JavaScript: javascript/modules.md
|
||||
- REST API: cloud/rest.md
|
||||
- FAQs: cloud/cloud_faq.md
|
||||
|
||||
- Quick start: basic.md
|
||||
- Concepts:
|
||||
@@ -357,6 +359,7 @@ nav:
|
||||
- 🐍 Python: python/saas-python.md
|
||||
- 👾 JavaScript: javascript/modules.md
|
||||
- REST API: cloud/rest.md
|
||||
- FAQs: cloud/cloud_faq.md
|
||||
|
||||
extra_css:
|
||||
- styles/global.css
|
||||
|
||||
@@ -141,14 +141,6 @@ recommend switching to stable releases.
|
||||
--8<-- "python/python/tests/docs/test_basic.py:connect_async"
|
||||
```
|
||||
|
||||
!!! note "Asynchronous Python API"
|
||||
|
||||
The asynchronous Python API is new and has some slight differences compared
|
||||
to the synchronous API. Feel free to start using the asynchronous version.
|
||||
Once all features have migrated we will start to move the synchronous API to
|
||||
use the same syntax as the asynchronous API. To help with this migration we
|
||||
have created a [migration guide](migration.md) detailing the differences.
|
||||
|
||||
=== "Typescript[^1]"
|
||||
|
||||
=== "@lancedb/lancedb"
|
||||
|
||||
34
docs/src/cloud/cloud_faq.md
Normal file
34
docs/src/cloud/cloud_faq.md
Normal file
@@ -0,0 +1,34 @@
|
||||
This section provides answers to the most common questions asked about LanceDB Cloud. By following these guidelines, you can ensure a smooth, performant experience with LanceDB Cloud.
|
||||
|
||||
### Should I reuse the database connection?
|
||||
Yes! It is recommended to establish a single database connection and maintain it throughout your interaction with the tables within.
|
||||
|
||||
LanceDB uses HTTP connections to communicate with the servers. By re-using the Connection object, you avoid the overhead of repeatedly establishing HTTP connections, significantly improving efficiency.
|
||||
|
||||
### Should I re-use the `Table` object?
|
||||
`table = db.open_table()` should be called once and used for all subsequent table operations. If there are changes to the opened table, `table` always reflect the **latest version** of the data.
|
||||
|
||||
### What should I do if I need to search for rows by `id`?
|
||||
LanceDB Cloud currently does not support an ID or primary key column. You are recommended to add a
|
||||
user-defined ID column. To significantly improve the query performance with SQL causes, a scalar BITMAP/BTREE index should be created on this column.
|
||||
|
||||
### What are the vector indexing types supported by LanceDB Cloud?
|
||||
We support `IVF_PQ` and `IVF_HNSW_SQ` as the `index_type` which is passed to `create_index`. LanceDB Cloud tunes the indexing parameters automatically to achieve the best tradeoff between query latency and query quality.
|
||||
|
||||
### When I add new rows to a table, do I need to manually update the index?
|
||||
No! LanceDB Cloud triggers an asynchronous background job to index the new vectors.
|
||||
|
||||
Even though indexing is asynchronous, your vectors will still be immediately searchable. LanceDB uses brute-force search to search over unindexed rows. This makes you new data is immediately available, but does increase latency temporarily. To disable the brute-force part of search, set the `fast_search` flag in your query to `true`.
|
||||
|
||||
### Do I need to reindex the whole dataset if only a small portion of the data is deleted or updated?
|
||||
No! Similar to adding data to the table, LanceDB Cloud triggers an asynchronous background job to update the existing indices. Therefore, no action is needed from users and there is absolutely no
|
||||
downtime expected.
|
||||
|
||||
### How do I know whether an index has been created?
|
||||
While index creation in LanceDB Cloud is generally fast, querying immediately after a `create_index` call may result in errors. It's recommended to use `list_indices` to verify index creation before querying.
|
||||
|
||||
### Why is my query latency higher than expected?
|
||||
Multiple factors can impact query latency. To reduce query latency, consider the following:
|
||||
- Send pre-warm queries: send a few queries to warm up the cache before an actual user query.
|
||||
- Check network latency: LanceDB Cloud is hosted in AWS `us-east-1` region. It is recommended to run queries from an EC2 instance that is in the same region.
|
||||
- Create scalar indices: If you are filtering on metadata, it is recommended to create scalar indices on those columns. This will speedup searches with metadata filtering. See [here](../guides/scalar_index.md) for more details on creating a scalar index.
|
||||
@@ -804,12 +804,13 @@ a table:
|
||||
|
||||
You can add new columns to the table with the `add_columns` method. New columns
|
||||
are filled with values based on a SQL expression. For example, you can add a new
|
||||
column `y` to the table and fill it with the value of `x + 1`.
|
||||
column `y` to the table, fill it with the value of `x * 2` and set the expected
|
||||
data type for it.
|
||||
|
||||
=== "Python"
|
||||
|
||||
```python
|
||||
table.add_columns({"double_price": "price * 2"})
|
||||
--8<-- "python/python/tests/docs/test_basic.py:add_columns"
|
||||
```
|
||||
**API Reference:** [lancedb.table.Table.add_columns][]
|
||||
|
||||
@@ -849,8 +850,7 @@ rewriting the column, which can be a heavy operation.
|
||||
|
||||
```python
|
||||
import pyarrow as pa
|
||||
table.alter_column({"path": "double_price", "rename": "dbl_price",
|
||||
"data_type": pa.float32(), "nullable": False})
|
||||
--8<-- "python/python/tests/docs/test_basic.py:alter_columns"
|
||||
```
|
||||
**API Reference:** [lancedb.table.Table.alter_columns][]
|
||||
|
||||
@@ -873,7 +873,7 @@ will remove the column from the schema.
|
||||
=== "Python"
|
||||
|
||||
```python
|
||||
table.drop_columns(["dbl_price"])
|
||||
--8<-- "python/python/tests/docs/test_basic.py:drop_columns"
|
||||
```
|
||||
**API Reference:** [lancedb.table.Table.drop_columns][]
|
||||
|
||||
|
||||
@@ -1,81 +1,14 @@
|
||||
# Rust-backed Client Migration Guide
|
||||
|
||||
In an effort to ensure all clients have the same set of capabilities we have begun migrating the
|
||||
python and node clients onto a common Rust base library. In python, this new client is part of
|
||||
the same lancedb package, exposed as an asynchronous client. Once the asynchronous client has
|
||||
reached full functionality we will begin migrating the synchronous library to be a thin wrapper
|
||||
around the asynchronous client.
|
||||
In an effort to ensure all clients have the same set of capabilities we have
|
||||
migrated the Python and Node clients onto a common Rust base library. In Python,
|
||||
both the synchronous and asynchronous clients are based on this implementation.
|
||||
In Node, the new client is available as `@lancedb/lancedb`, which replaces
|
||||
the existing `vectordb` package.
|
||||
|
||||
This guide describes the differences between the two APIs and will hopefully assist users
|
||||
This guide describes the differences between the two Node APIs and will hopefully assist users
|
||||
that would like to migrate to the new API.
|
||||
|
||||
## Python
|
||||
### Closeable Connections
|
||||
|
||||
The Connection now has a `close` method. You can call this when
|
||||
you are done with the connection to eagerly free resources. Currently
|
||||
this is limited to freeing/closing the HTTP connection for remote
|
||||
connections. In the future we may add caching or other resources to
|
||||
native connections so this is probably a good practice even if you
|
||||
aren't using remote connections.
|
||||
|
||||
In addition, the connection can be used as a context manager which may
|
||||
be a more convenient way to ensure the connection is closed.
|
||||
|
||||
```python
|
||||
import lancedb
|
||||
|
||||
async def my_async_fn():
|
||||
with await lancedb.connect_async("my_uri") as db:
|
||||
print(await db.table_names())
|
||||
```
|
||||
|
||||
It is not mandatory to call the `close` method. If you do not call it
|
||||
then the connection will be closed when the object is garbage collected.
|
||||
|
||||
### Closeable Table
|
||||
|
||||
The Table now also has a `close` method, similar to the connection. This
|
||||
can be used to eagerly free the cache used by a Table object. Similar to
|
||||
the connection, it can be used as a context manager and it is not mandatory
|
||||
to call the `close` method.
|
||||
|
||||
#### Changes to Table APIs
|
||||
|
||||
- Previously `Table.schema` was a property. Now it is an async method.
|
||||
- The method `Table.__len__` was removed and `len(table)` will no longer
|
||||
work. Use `Table.count_rows` instead.
|
||||
|
||||
#### Creating Indices
|
||||
|
||||
The `Table.create_index` method is now used for creating both vector indices
|
||||
and scalar indices. It currently requires a column name to be specified (the
|
||||
column to index). Vector index defaults are now smarter and scale better with
|
||||
the size of the data.
|
||||
|
||||
To specify index configuration details you will need to specify which kind of
|
||||
index you are using.
|
||||
|
||||
#### Querying
|
||||
|
||||
The `Table.search` method has been renamed to `AsyncTable.vector_search` for
|
||||
clarity.
|
||||
|
||||
### Features not yet supported
|
||||
|
||||
The following features are not yet supported by the asynchronous API. However,
|
||||
we plan to support them soon.
|
||||
|
||||
- You cannot specify an embedding function when creating or opening a table.
|
||||
You must calculate embeddings yourself if using the asynchronous API
|
||||
- The merge insert operation is not supported in the asynchronous API
|
||||
- Cleanup / compact / optimize indices are not supported in the asynchronous API
|
||||
- add / alter columns is not supported in the asynchronous API
|
||||
- The asynchronous API does not yet support any full text search or reranking
|
||||
search
|
||||
- Remote connections to LanceDb Cloud are not yet supported.
|
||||
- The method Table.head is not yet supported.
|
||||
|
||||
## TypeScript/JavaScript
|
||||
|
||||
For JS/TS users, we offer a brand new SDK [@lancedb/lancedb](https://www.npmjs.com/package/@lancedb/lancedb)
|
||||
|
||||
@@ -47,6 +47,8 @@ is also an [asynchronous API client](#connections-asynchronous).
|
||||
|
||||
::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
|
||||
|
||||
::: lancedb.embeddings.base.EmbeddingFunctionConfig
|
||||
|
||||
::: lancedb.embeddings.base.EmbeddingFunction
|
||||
|
||||
::: lancedb.embeddings.base.TextEmbeddingFunction
|
||||
@@ -127,8 +129,16 @@ lists the indices that LanceDb supports.
|
||||
|
||||
::: lancedb.index.LabelList
|
||||
|
||||
::: lancedb.index.FTS
|
||||
|
||||
::: lancedb.index.IvfPq
|
||||
|
||||
::: lancedb.index.HnswPq
|
||||
|
||||
::: lancedb.index.HnswSq
|
||||
|
||||
::: lancedb.index.IvfFlat
|
||||
|
||||
## Querying (Asynchronous)
|
||||
|
||||
Queries allow you to return data from your database. Basic queries can be
|
||||
|
||||
@@ -17,4 +17,8 @@ pip install lancedb
|
||||
## Table
|
||||
|
||||
::: lancedb.remote.table.RemoteTable
|
||||
|
||||
options:
|
||||
filters:
|
||||
- "!cleanup_old_versions"
|
||||
- "!compact_files"
|
||||
- "!optimize"
|
||||
|
||||
@@ -13,11 +13,15 @@ A vector search finds the approximate or exact nearest neighbors to a given quer
|
||||
Distance metrics are a measure of the similarity between a pair of vectors.
|
||||
Currently, LanceDB supports the following metrics:
|
||||
|
||||
| Metric | Description |
|
||||
| -------- | --------------------------------------------------------------------------- |
|
||||
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
|
||||
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
|
||||
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
|
||||
| Metric | Description |
|
||||
| --------- | --------------------------------------------------------------------------- |
|
||||
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
|
||||
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
|
||||
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
|
||||
| `hamming` | [Hamming Distance](https://en.wikipedia.org/wiki/Hamming_distance) |
|
||||
|
||||
!!! note
|
||||
The `hamming` metric is only available for binary vectors.
|
||||
|
||||
## Exhaustive search (kNN)
|
||||
|
||||
@@ -107,6 +111,31 @@ an ANN search means that using an index often involves a trade-off between recal
|
||||
See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ`
|
||||
indexes work in LanceDB.
|
||||
|
||||
## Binary vector
|
||||
|
||||
LanceDB supports binary vectors as a data type, and has the ability to search binary vectors with hamming distance. The binary vectors are stored as uint8 arrays (every 8 bits are stored as a byte):
|
||||
|
||||
!!! note
|
||||
The dim of the binary vector must be a multiple of 8. A vector of dim 128 will be stored as a uint8 array of size 16.
|
||||
|
||||
=== "Python"
|
||||
|
||||
=== "sync API"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
|
||||
|
||||
--8<-- "python/python/tests/docs/test_binary_vector.py:sync_binary_vector"
|
||||
```
|
||||
|
||||
=== "async API"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
|
||||
|
||||
--8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector"
|
||||
```
|
||||
|
||||
## Output search results
|
||||
|
||||
LanceDB returns vector search results via different formats commonly used in python.
|
||||
|
||||
@@ -16,6 +16,7 @@ excluded_globs = [
|
||||
"../src/concepts/*.md",
|
||||
"../src/ann_indexes.md",
|
||||
"../src/basic.md",
|
||||
"../src/search.md",
|
||||
"../src/hybrid_search/hybrid_search.md",
|
||||
"../src/reranking/*.md",
|
||||
"../src/guides/tuning_retrievers/*.md",
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.14.1-beta.1</version>
|
||||
<version>0.14.1-final.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.14.1-beta.1</version>
|
||||
<version>0.14.1-final.0</version>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<name>LanceDB Parent</name>
|
||||
|
||||
111
node/package-lock.json
generated
111
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -52,14 +52,14 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.1"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.14.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.14.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
|
||||
"@lancedb/vectordb-linux-arm64-musl": "0.14.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.14.1",
|
||||
"@lancedb/vectordb-linux-x64-musl": "0.14.1",
|
||||
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.14.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@apache-arrow/ts": "^14.0.2",
|
||||
@@ -329,6 +329,97 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.14.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.14.1.tgz",
|
||||
"integrity": "sha512-6t7XHR7dBjDmAS/kz5wbe7LPhKW+WkFA16ZPyh0lmuxfnss4VvN3LE6qQBHjzYzB9U6Nu/4ktQ50xZGEPTnc5A==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.14.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.14.1.tgz",
|
||||
"integrity": "sha512-8q6Kd6XnNPKN8wqj75pHVQ4KFl6z9BaI6lWDiEaCNcO3bjPZkcLFNosJq4raxZ9iUi50Yl0qFJ6qR0XFVTwnnw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.14.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.14.1.tgz",
|
||||
"integrity": "sha512-4djEMmeNb+p6nW/C4xb8wdMwnIbWfO8fYAwiplOxzxeOpPaUC9rhwUUDCbrJDCpMa8RP5ED4/jC6yT8epaDMDw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-musl": {
|
||||
"version": "0.14.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.14.1.tgz",
|
||||
"integrity": "sha512-c33hSsp16pnC58plzx1OXuifp9Rachx/MshE/L/OReoutt74fFdrRJwUjE4UCAysyY5QdvTrNm9OhDjopQK2Bw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.14.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.14.1.tgz",
|
||||
"integrity": "sha512-psu6cH9iLiSbUEZD1EWbOA4THGYSwJvS2XICO9yN7A6D41AP/ynYMRZNKWo1fpdi2Fjb0xNQwiNhQyqwbi5gzA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-musl": {
|
||||
"version": "0.14.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.14.1.tgz",
|
||||
"integrity": "sha512-Rg4VWW80HaTFmR7EvNSu+nfRQQM8beO/otBn/Nus5mj5zFw/7cacGRmiEYhDnk5iAn8nauV+Jsi9j2U+C2hp5w==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.14.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.14.1.tgz",
|
||||
"integrity": "sha512-XbifasmMbQIt3V9P0AtQND6M3XFiIAc1ZIgmjzBjOmxwqw4sQUwHMyJGIGOzKFZTK3fPJIGRHId7jAzXuBgfQg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
]
|
||||
},
|
||||
"node_modules/@neon-rs/cli": {
|
||||
"version": "0.0.160",
|
||||
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"private": false,
|
||||
"main": "dist/index.js",
|
||||
@@ -92,13 +92,13 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-x64": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-darwin-arm64": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.1",
|
||||
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.1"
|
||||
"@lancedb/vectordb-darwin-x64": "0.14.1",
|
||||
"@lancedb/vectordb-darwin-arm64": "0.14.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.14.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
|
||||
"@lancedb/vectordb-linux-x64-musl": "0.14.1",
|
||||
"@lancedb/vectordb-linux-arm64-musl": "0.14.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.14.1",
|
||||
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.14.1-beta.1"
|
||||
version = "0.14.1"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
@@ -12,7 +12,10 @@ categories.workspace = true
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
async-trait.workspace = true
|
||||
arrow-ipc.workspace = true
|
||||
arrow-array.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
lancedb = { path = "../rust/lancedb", features = ["remote"] }
|
||||
|
||||
@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";
|
||||
|
||||
import {
|
||||
convertToTable,
|
||||
fromBufferToRecordBatch,
|
||||
fromRecordBatchToBuffer,
|
||||
fromTableToBuffer,
|
||||
makeArrowTable,
|
||||
makeEmptyTable,
|
||||
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("converting record batches to buffers", function () {
|
||||
it("can convert to buffered record batch and back again", async function () {
|
||||
const records = [
|
||||
{ text: "dog", vector: [0.1, 0.2] },
|
||||
{ text: "cat", vector: [0.3, 0.4] },
|
||||
];
|
||||
const table = await convertToTable(records);
|
||||
const batch = table.batches[0];
|
||||
|
||||
const buffer = await fromRecordBatchToBuffer(batch);
|
||||
const result = await fromBufferToRecordBatch(buffer);
|
||||
|
||||
expect(JSON.stringify(batch.toArray())).toEqual(
|
||||
JSON.stringify(result?.toArray()),
|
||||
);
|
||||
});
|
||||
|
||||
it("converting from buffer returns null if buffer has no record batches", async function () {
|
||||
const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
|
||||
expect(result).toEqual(null);
|
||||
});
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
79
nodejs/__test__/rerankers.test.ts
Normal file
79
nodejs/__test__/rerankers.test.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import { RecordBatch } from "apache-arrow";
|
||||
import * as tmp from "tmp";
|
||||
import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
|
||||
import { RRFReranker } from "../lancedb/rerankers";
|
||||
|
||||
describe("rerankers", function () {
|
||||
let tmpDir: tmp.DirResult;
|
||||
let conn: Connection;
|
||||
let table: Table;
|
||||
|
||||
beforeEach(async () => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
conn = await connect(tmpDir.name);
|
||||
table = await conn.createTable("mytable", [
|
||||
{ vector: [0.1, 0.1], text: "dog" },
|
||||
{ vector: [0.2, 0.2], text: "cat" },
|
||||
]);
|
||||
await table.createIndex("text", {
|
||||
config: Index.fts(),
|
||||
replace: true,
|
||||
});
|
||||
});
|
||||
|
||||
it("will query with the custom reranker", async function () {
|
||||
const expectedResult = [
|
||||
{
|
||||
text: "albert",
|
||||
// biome-ignore lint/style/useNamingConvention: this is the lance field name
|
||||
_relevance_score: 0.99,
|
||||
},
|
||||
];
|
||||
class MyCustomReranker {
|
||||
async rerankHybrid(
|
||||
_query: string,
|
||||
_vecResults: RecordBatch,
|
||||
_ftsResults: RecordBatch,
|
||||
): Promise<RecordBatch> {
|
||||
// no reranker logic, just return some static data
|
||||
const table = makeArrowTable(expectedResult);
|
||||
return table.batches[0];
|
||||
}
|
||||
}
|
||||
|
||||
let result = await table
|
||||
.query()
|
||||
.nearestTo([0.1, 0.1])
|
||||
.fullTextSearch("dog")
|
||||
.rerank(new MyCustomReranker())
|
||||
.select(["text"])
|
||||
.limit(5)
|
||||
.toArray();
|
||||
|
||||
result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
|
||||
expect(result).toEqual([
|
||||
{
|
||||
text: "albert",
|
||||
// biome-ignore lint/style/useNamingConvention: this is the lance field name
|
||||
_relevance_score: 0.99,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("will query with RRFReranker", async function () {
|
||||
// smoke test to see if the Rust wrapping Typescript is wired up correctly
|
||||
const result = await table
|
||||
.query()
|
||||
.nearestTo([0.1, 0.1])
|
||||
.fullTextSearch("dog")
|
||||
.rerank(await RRFReranker.create())
|
||||
.select(["text"])
|
||||
.limit(5)
|
||||
.toArray();
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
@@ -1058,6 +1058,26 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
expect(results[0].text).toBe(data[0].text);
|
||||
});
|
||||
|
||||
test("full text search without lowercase", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const data = [
|
||||
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
|
||||
{ text: "Hello World", vector: [0.4, 0.5, 0.6] },
|
||||
];
|
||||
const table = await db.createTable("test", data);
|
||||
await table.createIndex("text", {
|
||||
config: Index.fts({ withPosition: false }),
|
||||
});
|
||||
const results = await table.search("hello").toArray();
|
||||
expect(results.length).toBe(2);
|
||||
|
||||
await table.createIndex("text", {
|
||||
config: Index.fts({ withPosition: false, lowercase: false }),
|
||||
});
|
||||
const results2 = await table.search("hello").toArray();
|
||||
expect(results2.length).toBe(1);
|
||||
});
|
||||
|
||||
test("full text search phrase query", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const data = [
|
||||
|
||||
@@ -119,7 +119,9 @@ test("basic table examples", async () => {
|
||||
|
||||
{
|
||||
// --8<-- [start:add_columns]
|
||||
await tbl.addColumns([{ name: "double_price", valueSql: "price * 2" }]);
|
||||
await tbl.addColumns([
|
||||
{ name: "double_price", valueSql: "cast((price * 2) as Float)" },
|
||||
]);
|
||||
// --8<-- [end:add_columns]
|
||||
// --8<-- [start:alter_columns]
|
||||
await tbl.alterColumns([
|
||||
|
||||
@@ -27,7 +27,9 @@ import {
|
||||
List,
|
||||
Null,
|
||||
RecordBatch,
|
||||
RecordBatchFileReader,
|
||||
RecordBatchFileWriter,
|
||||
RecordBatchReader,
|
||||
RecordBatchStreamWriter,
|
||||
Schema,
|
||||
Struct,
|
||||
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a single record batch from a buffer.
|
||||
*
|
||||
* Returns null if the buffer does not contain a record batch
|
||||
*/
|
||||
export async function fromBufferToRecordBatch(
|
||||
data: Buffer,
|
||||
): Promise<RecordBatch | null> {
|
||||
const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
|
||||
.value;
|
||||
const recordBatch = iter?.next().value;
|
||||
return recordBatch || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a buffer containing a single record batch
|
||||
*/
|
||||
export async function fromRecordBatchToBuffer(
|
||||
batch: RecordBatch,
|
||||
): Promise<Buffer> {
|
||||
const writer = new RecordBatchFileWriter().writeAll([batch]);
|
||||
return Buffer.from(await writer.toUint8Array());
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
|
||||
*
|
||||
|
||||
@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
|
||||
export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
export * as rerankers from "./rerankers";
|
||||
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI.
|
||||
|
||||
@@ -349,6 +349,52 @@ export interface FtsOptions {
|
||||
* which will make the index smaller and faster to build, but will not support phrase queries.
|
||||
*/
|
||||
withPosition?: boolean;
|
||||
|
||||
/**
|
||||
* The tokenizer to use when building the index.
|
||||
* The default is "simple".
|
||||
*
|
||||
* The following tokenizers are available:
|
||||
*
|
||||
* "simple" - Simple tokenizer. This tokenizer splits the text into tokens using whitespace and punctuation as a delimiter.
|
||||
*
|
||||
* "whitespace" - Whitespace tokenizer. This tokenizer splits the text into tokens using whitespace as a delimiter.
|
||||
*
|
||||
* "raw" - Raw tokenizer. This tokenizer does not split the text into tokens and indexes the entire text as a single token.
|
||||
*/
|
||||
baseTokenizer?: "simple" | "whitespace" | "raw";
|
||||
|
||||
/**
|
||||
* language for stemming and stop words
|
||||
* this is only used when `stem` or `remove_stop_words` is true
|
||||
*/
|
||||
language?: string;
|
||||
|
||||
/**
|
||||
* maximum token length
|
||||
* tokens longer than this length will be ignored
|
||||
*/
|
||||
maxTokenLength?: number;
|
||||
|
||||
/**
|
||||
* whether to lowercase tokens
|
||||
*/
|
||||
lowercase?: boolean;
|
||||
|
||||
/**
|
||||
* whether to stem tokens
|
||||
*/
|
||||
stem?: boolean;
|
||||
|
||||
/**
|
||||
* whether to remove stop words
|
||||
*/
|
||||
removeStopWords?: boolean;
|
||||
|
||||
/**
|
||||
* whether to remove punctuation
|
||||
*/
|
||||
asciiFolding?: boolean;
|
||||
}
|
||||
|
||||
export class Index {
|
||||
@@ -450,7 +496,18 @@ export class Index {
|
||||
* For now, the full text search index only supports English, and doesn't support phrase search.
|
||||
*/
|
||||
static fts(options?: Partial<FtsOptions>) {
|
||||
return new Index(LanceDbIndex.fts(options?.withPosition));
|
||||
return new Index(
|
||||
LanceDbIndex.fts(
|
||||
options?.withPosition,
|
||||
options?.baseTokenizer,
|
||||
options?.language,
|
||||
options?.maxTokenLength,
|
||||
options?.lowercase,
|
||||
options?.stem,
|
||||
options?.removeStopWords,
|
||||
options?.asciiFolding,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -16,6 +16,8 @@ import {
|
||||
Table as ArrowTable,
|
||||
type IntoVector,
|
||||
RecordBatch,
|
||||
fromBufferToRecordBatch,
|
||||
fromRecordBatchToBuffer,
|
||||
tableFromIPC,
|
||||
} from "./arrow";
|
||||
import { type IvfPqOptions } from "./indices";
|
||||
@@ -25,6 +27,7 @@ import {
|
||||
Table as NativeTable,
|
||||
VectorQuery as NativeVectorQuery,
|
||||
} from "./native";
|
||||
import { Reranker } from "./rerankers";
|
||||
export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
|
||||
private promisedInner?: Promise<NativeBatchIterator>;
|
||||
private inner?: NativeBatchIterator;
|
||||
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
rerank(reranker: Reranker): VectorQuery {
|
||||
super.doCall((inner) =>
|
||||
inner.rerank({
|
||||
rerankHybrid: async (_, args) => {
|
||||
const vecResults = await fromBufferToRecordBatch(args.vecResults);
|
||||
const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
|
||||
const result = await reranker.rerankHybrid(
|
||||
args.query,
|
||||
vecResults as RecordBatch,
|
||||
ftsResults as RecordBatch,
|
||||
);
|
||||
|
||||
const buffer = fromRecordBatchToBuffer(result);
|
||||
return buffer;
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
/** A builder for LanceDB queries. */
|
||||
|
||||
17
nodejs/lancedb/rerankers/index.ts
Normal file
17
nodejs/lancedb/rerankers/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import { RecordBatch } from "apache-arrow";
|
||||
|
||||
export * from "./rrf";
|
||||
|
||||
// Interface for a reranker. A reranker is used to rerank the results from a
|
||||
// vector and FTS search. This is useful for combining the results from both
|
||||
// search methods.
|
||||
export interface Reranker {
|
||||
rerankHybrid(
|
||||
query: string,
|
||||
vecResults: RecordBatch,
|
||||
ftsResults: RecordBatch,
|
||||
): Promise<RecordBatch>;
|
||||
}
|
||||
40
nodejs/lancedb/rerankers/rrf.ts
Normal file
40
nodejs/lancedb/rerankers/rrf.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import { RecordBatch } from "apache-arrow";
|
||||
import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
|
||||
import { RrfReranker as NativeRRFReranker } from "../native";
|
||||
|
||||
/**
|
||||
* Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
|
||||
*
|
||||
* Internally this uses the Rust implementation
|
||||
*/
|
||||
export class RRFReranker {
|
||||
private inner: NativeRRFReranker;
|
||||
|
||||
constructor(inner: NativeRRFReranker) {
|
||||
this.inner = inner;
|
||||
}
|
||||
|
||||
public static async create(k: number = 60) {
|
||||
return new RRFReranker(
|
||||
await NativeRRFReranker.tryNew(new Float32Array([k])),
|
||||
);
|
||||
}
|
||||
|
||||
async rerankHybrid(
|
||||
query: string,
|
||||
vecResults: RecordBatch,
|
||||
ftsResults: RecordBatch,
|
||||
): Promise<RecordBatch> {
|
||||
const buffer = await this.inner.rerankHybrid(
|
||||
query,
|
||||
await fromRecordBatchToBuffer(vecResults),
|
||||
await fromRecordBatchToBuffer(ftsResults),
|
||||
);
|
||||
const recordBatch = await fromBufferToRecordBatch(buffer);
|
||||
|
||||
return recordBatch as RecordBatch;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.14.0",
|
||||
"version": "0.14.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.14.0",
|
||||
"version": "0.14.1",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.14.1-beta.1",
|
||||
"version": "0.14.1",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -96,11 +96,45 @@ impl Index {
|
||||
}
|
||||
|
||||
#[napi(factory)]
|
||||
pub fn fts(with_position: Option<bool>) -> Self {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn fts(
|
||||
with_position: Option<bool>,
|
||||
base_tokenizer: Option<String>,
|
||||
language: Option<String>,
|
||||
max_token_length: Option<u32>,
|
||||
lower_case: Option<bool>,
|
||||
stem: Option<bool>,
|
||||
remove_stop_words: Option<bool>,
|
||||
ascii_folding: Option<bool>,
|
||||
) -> Self {
|
||||
let mut opts = FtsIndexBuilder::default();
|
||||
let mut tokenizer_configs = opts.tokenizer_configs.clone();
|
||||
if let Some(with_position) = with_position {
|
||||
opts = opts.with_position(with_position);
|
||||
}
|
||||
if let Some(base_tokenizer) = base_tokenizer {
|
||||
tokenizer_configs = tokenizer_configs.base_tokenizer(base_tokenizer);
|
||||
}
|
||||
if let Some(language) = language {
|
||||
tokenizer_configs = tokenizer_configs.language(&language).unwrap();
|
||||
}
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
tokenizer_configs = tokenizer_configs.max_token_length(Some(max_token_length as usize));
|
||||
}
|
||||
if let Some(lower_case) = lower_case {
|
||||
tokenizer_configs = tokenizer_configs.lower_case(lower_case);
|
||||
}
|
||||
if let Some(stem) = stem {
|
||||
tokenizer_configs = tokenizer_configs.stem(stem);
|
||||
}
|
||||
if let Some(remove_stop_words) = remove_stop_words {
|
||||
tokenizer_configs = tokenizer_configs.remove_stop_words(remove_stop_words);
|
||||
}
|
||||
if let Some(ascii_folding) = ascii_folding {
|
||||
tokenizer_configs = tokenizer_configs.ascii_folding(ascii_folding);
|
||||
}
|
||||
opts.tokenizer_configs = tokenizer_configs;
|
||||
|
||||
Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ mod iterator;
|
||||
pub mod merge;
|
||||
mod query;
|
||||
pub mod remote;
|
||||
mod rerankers;
|
||||
mod table;
|
||||
mod util;
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use lancedb::index::scalar::FullTextSearchQuery;
|
||||
use lancedb::query::ExecutableQuery;
|
||||
use lancedb::query::Query as LanceDbQuery;
|
||||
@@ -25,6 +27,8 @@ use napi_derive::napi;
|
||||
use crate::error::convert_error;
|
||||
use crate::error::NapiErrorExt;
|
||||
use crate::iterator::RecordBatchIterator;
|
||||
use crate::rerankers::Reranker;
|
||||
use crate::rerankers::RerankerCallbacks;
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
#[napi]
|
||||
@@ -218,6 +222,14 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
|
||||
self.inner = self
|
||||
.inner
|
||||
.clone()
|
||||
.rerank(Arc::new(Reranker::new(callbacks)));
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
|
||||
147
nodejs/src/rerankers.rs
Normal file
147
nodejs/src/rerankers.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
use async_trait::async_trait;
|
||||
use napi::{
|
||||
bindgen_prelude::*,
|
||||
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
|
||||
};
|
||||
use napi_derive::napi;
|
||||
|
||||
use lancedb::ipc::batches_to_ipc_file;
|
||||
use lancedb::rerankers::Reranker as LanceDBReranker;
|
||||
use lancedb::{error::Error, ipc::ipc_file_to_batches};
|
||||
|
||||
use crate::error::NapiErrorExt;
|
||||
|
||||
/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
|
||||
/// This contains references to the callbacks that can be used to invoke the
|
||||
/// reranking methods on the NodeJS implementation and handles serializing the
|
||||
/// record batches to Arrow IPC buffers.
|
||||
#[napi]
|
||||
pub struct Reranker {
|
||||
/// callback to the Javascript which will call the rerankHybrid method of
|
||||
/// some Reranker implementation
|
||||
rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl Reranker {
|
||||
#[napi]
|
||||
pub fn new(callbacks: RerankerCallbacks) -> Self {
|
||||
let rerank_hybrid = callbacks
|
||||
.rerank_hybrid
|
||||
.create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
|
||||
.unwrap();
|
||||
|
||||
Self { rerank_hybrid }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl lancedb::rerankers::Reranker for Reranker {
|
||||
async fn rerank_hybrid(
|
||||
&self,
|
||||
query: &str,
|
||||
vector_results: RecordBatch,
|
||||
fts_results: RecordBatch,
|
||||
) -> lancedb::error::Result<RecordBatch> {
|
||||
let callback_args = RerankHybridCallbackArgs {
|
||||
query: query.to_string(),
|
||||
vec_results: batches_to_ipc_file(&[vector_results])?,
|
||||
fts_results: batches_to_ipc_file(&[fts_results])?,
|
||||
};
|
||||
let promised_buffer: Promise<Buffer> = self
|
||||
.rerank_hybrid
|
||||
.call_async(Ok(callback_args))
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("napi error status={}, reason={}", e.status, e.reason),
|
||||
})?;
|
||||
let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
|
||||
message: format!("napi error status={}, reason={}", e.status, e.reason),
|
||||
})?;
|
||||
let mut reader = ipc_file_to_batches(buffer.to_vec())?;
|
||||
let result = reader.next().ok_or(Error::Runtime {
|
||||
message: "reranker result deserialization failed".to_string(),
|
||||
})??;
|
||||
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Reranker {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("NodeJSRerankerWrapper")
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct RerankerCallbacks {
|
||||
pub rerank_hybrid: JsFunction,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct RerankHybridCallbackArgs {
|
||||
pub query: String,
|
||||
pub vec_results: Vec<u8>,
|
||||
pub fts_results: Vec<u8>,
|
||||
}
|
||||
|
||||
fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
|
||||
let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
|
||||
reader
|
||||
.next()
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: "expected buffer containing record batch".to_string(),
|
||||
})
|
||||
.default_error()?
|
||||
.map_err(Error::from)
|
||||
.default_error()
|
||||
}
|
||||
|
||||
/// Wrapper around rust RRFReranker
|
||||
#[napi]
|
||||
pub struct RRFReranker {
|
||||
inner: lancedb::rerankers::rrf::RRFReranker,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl RRFReranker {
|
||||
#[napi]
|
||||
pub async fn try_new(k: &[f32]) -> Result<Self> {
|
||||
let k = k
|
||||
.first()
|
||||
.copied()
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: "must supply RRF Reranker constructor arg 'k'".to_string(),
|
||||
})
|
||||
.default_error()?;
|
||||
|
||||
Ok(Self {
|
||||
inner: lancedb::rerankers::rrf::RRFReranker::new(k),
|
||||
})
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn rerank_hybrid(
|
||||
&self,
|
||||
query: String,
|
||||
vec_results: Buffer,
|
||||
fts_results: Buffer,
|
||||
) -> Result<Buffer> {
|
||||
let vec_results = buffer_to_record_batch(vec_results)?;
|
||||
let fts_results = buffer_to_record_batch(fts_results)?;
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.rerank_hybrid(&query, vec_results, fts_results)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result_buff = batches_to_ipc_file(&[result]).default_error()?;
|
||||
|
||||
Ok(Buffer::from(result_buff.as_ref()))
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
|
||||
"l2" => Ok(DistanceType::L2),
|
||||
"cosine" => Ok(DistanceType::Cosine),
|
||||
"dot" => Ok(DistanceType::Dot),
|
||||
"hamming" => Ok(DistanceType::Hamming),
|
||||
_ => Err(napi::Error::from_reason(format!(
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
|
||||
distance_type.as_ref()
|
||||
))),
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.17.1-beta.2"
|
||||
current_version = "0.17.2-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.17.1-beta.2"
|
||||
version = "0.17.2-beta.1"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
# version in Cargo.toml
|
||||
dynamic = ["version"]
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.20.0",
|
||||
"pylance==0.21.1b1",
|
||||
"tqdm>=4.27.0",
|
||||
"pydantic>=1.10",
|
||||
"packaging",
|
||||
|
||||
@@ -70,7 +70,7 @@ def connect(
|
||||
default configuration is used.
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
<https://lancedb.github.io/lancedb/guides/storage/>
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -82,11 +82,13 @@ def connect(
|
||||
|
||||
For object storage, use a URI prefix:
|
||||
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb",
|
||||
... storage_options={"aws_access_key_id": "***"})
|
||||
|
||||
Connect to LanceDB cloud:
|
||||
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...",
|
||||
... client_config={"retry_config": {"retries": 5}})
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -164,7 +166,7 @@ async def connect_async(
|
||||
default configuration is used.
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
<https://lancedb.github.io/lancedb/guides/storage/>
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
@@ -2,19 +2,8 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
class Index:
|
||||
@staticmethod
|
||||
def ivf_pq(
|
||||
distance_type: Optional[str],
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
max_iterations: Optional[int],
|
||||
sample_rate: Optional[int],
|
||||
) -> Index: ...
|
||||
@staticmethod
|
||||
def btree() -> Index: ...
|
||||
|
||||
class Connection(object):
|
||||
uri: str
|
||||
async def table_names(
|
||||
self, start_after: Optional[str], limit: Optional[int]
|
||||
) -> list[str]: ...
|
||||
@@ -46,11 +35,9 @@ class Table:
|
||||
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
|
||||
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
|
||||
async def count_rows(self, filter: Optional[str]) -> int: ...
|
||||
async def create_index(
|
||||
self, column: str, config: Optional[Index], replace: Optional[bool]
|
||||
): ...
|
||||
async def create_index(self, column: str, config, replace: Optional[bool]): ...
|
||||
async def version(self) -> int: ...
|
||||
async def checkout(self, version): ...
|
||||
async def checkout(self, version: int): ...
|
||||
async def checkout_latest(self): ...
|
||||
async def restore(self): ...
|
||||
async def list_indices(self) -> List[IndexConfig]: ...
|
||||
|
||||
@@ -23,3 +23,6 @@ class BackgroundEventLoop:
|
||||
|
||||
def run(self, future):
|
||||
return asyncio.run_coroutine_threadsafe(future, self.loop).result()
|
||||
|
||||
|
||||
LOOP = BackgroundEventLoop()
|
||||
|
||||
@@ -17,12 +17,13 @@ from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
|
||||
|
||||
from overrides import EnforceOverrides, override
|
||||
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
from overrides import EnforceOverrides, override # type: ignore
|
||||
|
||||
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
|
||||
from lancedb.background_loop import BackgroundEventLoop
|
||||
from lancedb.background_loop import LOOP
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from ._lancedb import connect as lancedb_connect # type: ignore
|
||||
from .table import (
|
||||
AsyncTable,
|
||||
LanceTable,
|
||||
@@ -43,8 +44,6 @@ if TYPE_CHECKING:
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
|
||||
LOOP = BackgroundEventLoop()
|
||||
|
||||
|
||||
class DBConnection(EnforceOverrides):
|
||||
"""An active LanceDB connection interface."""
|
||||
@@ -82,6 +81,10 @@ class DBConnection(EnforceOverrides):
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
@@ -119,6 +122,24 @@ class DBConnection(EnforceOverrides):
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. Options already set on the
|
||||
connection will be inherited by the table, but can be overridden here.
|
||||
See available options at
|
||||
<https://lancedb.github.io/lancedb/guides/storage/>
|
||||
data_storage_version: optional, str, default "stable"
|
||||
The version of the data storage format to use. Newer versions are more
|
||||
efficient but require newer versions of lance to read. The default is
|
||||
"stable" which will use the legacy v2 version. See the user guide
|
||||
for more details.
|
||||
enable_v2_manifest_paths: bool, optional, default False
|
||||
Use the new V2 manifest paths. These paths provide more efficient
|
||||
opening of datasets with many versions on object stores. WARNING:
|
||||
turning this on will make the dataset unreadable for older versions
|
||||
of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
|
||||
use the
|
||||
[Table.migrate_manifest_paths_v2][lancedb.table.Table.migrate_v2_manifest_paths]
|
||||
method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -140,7 +161,7 @@ class DBConnection(EnforceOverrides):
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
LanceTable(name='my_table', version=1, ...)
|
||||
>>> db["my_table"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -161,7 +182,7 @@ class DBConnection(EnforceOverrides):
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data)
|
||||
LanceTable(connection=..., name="table2")
|
||||
LanceTable(name='table2', version=1, ...)
|
||||
>>> db["table2"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -184,7 +205,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema)
|
||||
LanceTable(connection=..., name="table3")
|
||||
LanceTable(name='table3', version=1, ...)
|
||||
>>> db["table3"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -218,7 +239,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("price", pa.float32()),
|
||||
... ])
|
||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||
LanceTable(connection=..., name="table4")
|
||||
LanceTable(name='table4', version=1, ...)
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -226,7 +247,13 @@ class DBConnection(EnforceOverrides):
|
||||
def __getitem__(self, name: str) -> LanceTable:
|
||||
return self.open_table(name)
|
||||
|
||||
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
|
||||
def open_table(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
@@ -243,6 +270,11 @@ class DBConnection(EnforceOverrides):
|
||||
This cache applies to the entire opened table, across all indices.
|
||||
Setting this value higher will increase performance on larger datasets
|
||||
at the expense of more RAM
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. Options already set on the
|
||||
connection will be inherited by the table, but can be overridden here.
|
||||
See available options at
|
||||
<https://lancedb.github.io/lancedb/guides/storage/>
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -309,15 +341,15 @@ class LanceDBConnection(DBConnection):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
||||
... {"vector": [0.5, 1.3], "b": 4}])
|
||||
LanceTable(connection=..., name="my_table")
|
||||
LanceTable(name='my_table', version=1, ...)
|
||||
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
||||
LanceTable(connection=..., name="another_table")
|
||||
LanceTable(name='another_table', version=1, ...)
|
||||
>>> sorted(db.table_names())
|
||||
['another_table', 'my_table']
|
||||
>>> len(db)
|
||||
2
|
||||
>>> db["my_table"]
|
||||
LanceTable(connection=..., name="my_table")
|
||||
LanceTable(name='my_table', version=1, ...)
|
||||
>>> "my_table" in db
|
||||
True
|
||||
>>> db.drop_table("my_table")
|
||||
@@ -363,7 +395,7 @@ class LanceDBConnection(DBConnection):
|
||||
self._conn = AsyncConnection(LOOP.run(do_connect()))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
val = f"{self.__class__.__name__}({self._uri}"
|
||||
val = f"{self.__class__.__name__}(uri={self._uri!r}"
|
||||
if self.read_consistency_interval is not None:
|
||||
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
|
||||
val += ")"
|
||||
@@ -403,6 +435,10 @@ class LanceDBConnection(DBConnection):
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
|
||||
@@ -424,12 +460,19 @@ class LanceDBConnection(DBConnection):
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
storage_options=storage_options,
|
||||
data_storage_version=data_storage_version,
|
||||
enable_v2_manifest_paths=enable_v2_manifest_paths,
|
||||
)
|
||||
return tbl
|
||||
|
||||
@override
|
||||
def open_table(
|
||||
self, name: str, *, index_cache_size: Optional[int] = None
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> LanceTable:
|
||||
"""Open a table in the database.
|
||||
|
||||
@@ -442,7 +485,12 @@ class LanceDBConnection(DBConnection):
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
return LanceTable.open(self, name, index_cache_size=index_cache_size)
|
||||
return LanceTable.open(
|
||||
self,
|
||||
name,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str, ignore_missing: bool = False):
|
||||
@@ -455,13 +503,7 @@ class LanceDBConnection(DBConnection):
|
||||
ignore_missing: bool, default False
|
||||
If True, ignore if the table does not exist.
|
||||
"""
|
||||
try:
|
||||
LOOP.run(self._conn.drop_table(name))
|
||||
except ValueError as e:
|
||||
if not ignore_missing:
|
||||
raise e
|
||||
if f"Table '{name}' was not found" not in str(e):
|
||||
raise e
|
||||
LOOP.run(self._conn.drop_table(name, ignore_missing=ignore_missing))
|
||||
|
||||
@override
|
||||
def drop_database(self):
|
||||
@@ -524,6 +566,10 @@ class AsyncConnection(object):
|
||||
Any attempt to use the connection after it is closed will result in an error."""
|
||||
self._inner.close()
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._inner.uri
|
||||
|
||||
async def table_names(
|
||||
self, *, start_after: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> Iterable[str]:
|
||||
@@ -557,6 +603,7 @@ class AsyncConnection(object):
|
||||
fill_value: Optional[float] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
embedding_functions: List[EmbeddingFunctionConfig] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
use_legacy_format: Optional[bool] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
@@ -601,7 +648,7 @@ class AsyncConnection(object):
|
||||
Additional options for the storage backend. Options already set on the
|
||||
connection will be inherited by the table, but can be overridden here.
|
||||
See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
<https://lancedb.github.io/lancedb/guides/storage/>
|
||||
data_storage_version: optional, str, default "stable"
|
||||
The version of the data storage format to use. Newer versions are more
|
||||
efficient but require newer versions of lance to read. The default is
|
||||
@@ -730,6 +777,17 @@ class AsyncConnection(object):
|
||||
"""
|
||||
metadata = None
|
||||
|
||||
if embedding_functions is not None:
|
||||
# If we passed in embedding functions explicitly
|
||||
# then we'll override any schema metadata that
|
||||
# may was implicitly specified by the LanceModel schema
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
data, schema = sanitize_create_table(
|
||||
data, schema, metadata, on_bad_vectors, fill_value
|
||||
)
|
||||
|
||||
# Defining defaults here and not in function prototype. In the future
|
||||
# these defaults will move into rust so better to keep them as None.
|
||||
if on_bad_vectors is None:
|
||||
@@ -791,7 +849,7 @@ class AsyncConnection(object):
|
||||
Additional options for the storage backend. Options already set on the
|
||||
connection will be inherited by the table, but can be overridden here.
|
||||
See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
<https://lancedb.github.io/lancedb/guides/storage/>
|
||||
index_cache_size: int, default 256
|
||||
Set the size of the index cache, specified as a number of entries
|
||||
|
||||
@@ -822,15 +880,23 @@ class AsyncConnection(object):
|
||||
"""
|
||||
await self._inner.rename_table(old_name, new_name)
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
async def drop_table(self, name: str, *, ignore_missing: bool = False):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
ignore_missing: bool, default False
|
||||
If True, ignore if the table does not exist.
|
||||
"""
|
||||
await self._inner.drop_table(name)
|
||||
try:
|
||||
await self._inner.drop_table(name)
|
||||
except ValueError as e:
|
||||
if not ignore_missing:
|
||||
raise e
|
||||
if f"Table '{name}' was not found" not in str(e):
|
||||
raise e
|
||||
|
||||
async def drop_database(self):
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
|
||||
from ._lancedb import (
|
||||
Index as LanceDbIndex,
|
||||
)
|
||||
from ._lancedb import (
|
||||
IndexConfig,
|
||||
)
|
||||
@@ -29,6 +27,7 @@ lang_mapping = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BTree:
|
||||
"""Describes a btree index configuration
|
||||
|
||||
@@ -50,10 +49,10 @@ class BTree:
|
||||
the block size may be added in the future.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._inner = LanceDbIndex.btree()
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Bitmap:
|
||||
"""Describe a Bitmap index configuration.
|
||||
|
||||
@@ -73,10 +72,10 @@ class Bitmap:
|
||||
requires 128 / 8 * 1Bi bytes on disk.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._inner = LanceDbIndex.bitmap()
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LabelList:
|
||||
"""Describe a LabelList index configuration.
|
||||
|
||||
@@ -87,41 +86,57 @@ class LabelList:
|
||||
For example, it works with `tags`, `categories`, `keywords`, etc.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._inner = LanceDbIndex.label_list()
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FTS:
|
||||
"""Describe a FTS index configuration.
|
||||
|
||||
`FTS` is a full-text search index that can be used on `String` columns
|
||||
|
||||
For example, it works with `title`, `description`, `content`, etc.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
with_position : bool, default True
|
||||
Whether to store the position of the token in the document. Setting this
|
||||
to False can reduce the size of the index and improve indexing speed,
|
||||
but it will disable support for phrase queries.
|
||||
base_tokenizer : str, default "simple"
|
||||
The base tokenizer to use for tokenization. Options are:
|
||||
- "simple": Splits text by whitespace and punctuation.
|
||||
- "whitespace": Split text by whitespace, but not punctuation.
|
||||
- "raw": No tokenization. The entire text is treated as a single token.
|
||||
language : str, default "English"
|
||||
The language to use for tokenization.
|
||||
max_token_length : int, default 40
|
||||
The maximum token length to index. Tokens longer than this length will be
|
||||
ignored.
|
||||
lower_case : bool, default True
|
||||
Whether to convert the token to lower case. This makes queries case-insensitive.
|
||||
stem : bool, default False
|
||||
Whether to stem the token. Stemming reduces words to their root form.
|
||||
For example, in English "running" and "runs" would both be reduced to "run".
|
||||
remove_stop_words : bool, default False
|
||||
Whether to remove stop words. Stop words are common words that are often
|
||||
removed from text before indexing. For example, in English "the" and "and".
|
||||
ascii_folding : bool, default False
|
||||
Whether to fold ASCII characters. This converts accented characters to
|
||||
their ASCII equivalent. For example, "café" would be converted to "cafe".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
with_position: bool = True,
|
||||
base_tokenizer: str = "simple",
|
||||
language: str = "English",
|
||||
max_token_length: Optional[int] = 40,
|
||||
lower_case: bool = True,
|
||||
stem: bool = False,
|
||||
remove_stop_words: bool = False,
|
||||
ascii_folding: bool = False,
|
||||
):
|
||||
self._inner = LanceDbIndex.fts(
|
||||
with_position=with_position,
|
||||
base_tokenizer=base_tokenizer,
|
||||
language=language,
|
||||
max_token_length=max_token_length,
|
||||
lower_case=lower_case,
|
||||
stem=stem,
|
||||
remove_stop_words=remove_stop_words,
|
||||
ascii_folding=ascii_folding,
|
||||
)
|
||||
with_position: bool = True
|
||||
base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple"
|
||||
language: str = "English"
|
||||
max_token_length: Optional[int] = 40
|
||||
lower_case: bool = True
|
||||
stem: bool = False
|
||||
remove_stop_words: bool = False
|
||||
ascii_folding: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class HnswPq:
|
||||
"""Describe a HNSW-PQ index configuration.
|
||||
|
||||
@@ -232,30 +247,17 @@ class HnswPq:
|
||||
search phase.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
num_bits: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
m: Optional[int] = None,
|
||||
ef_construction: Optional[int] = None,
|
||||
):
|
||||
self._inner = LanceDbIndex.hnsw_pq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
)
|
||||
distance_type: Literal["l2", "cosine", "dot"] = "l2"
|
||||
num_partitions: Optional[int] = None
|
||||
num_sub_vectors: Optional[int] = None
|
||||
num_bits: int = 8
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
m: int = 20
|
||||
ef_construction: int = 300
|
||||
|
||||
|
||||
@dataclass
|
||||
class HnswSq:
|
||||
"""Describe a HNSW-SQ index configuration.
|
||||
|
||||
@@ -345,26 +347,106 @@ class HnswSq:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
m: Optional[int] = None,
|
||||
ef_construction: Optional[int] = None,
|
||||
):
|
||||
self._inner = LanceDbIndex.hnsw_sq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
)
|
||||
distance_type: Literal["l2", "cosine", "dot"] = "l2"
|
||||
num_partitions: Optional[int] = None
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
m: int = 20
|
||||
ef_construction: int = 300
|
||||
|
||||
|
||||
@dataclass
|
||||
class IvfFlat:
|
||||
"""Describes an IVF Flat Index
|
||||
|
||||
This index stores raw vectors.
|
||||
These vectors are grouped into partitions of similar vectors.
|
||||
Each partition keeps track of a centroid which is
|
||||
the average value of all vectors in the group.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
distance_type: str, default "L2"
|
||||
The distance metric used to train the index
|
||||
|
||||
This is used when training the index to calculate the IVF partitions
|
||||
(vectors are grouped in partitions with similar vectors according to this
|
||||
distance type) and to calculate a subvector's code during quantization.
|
||||
|
||||
The distance type used to train an index MUST match the distance type used
|
||||
to search the index. Failure to do so will yield inaccurate results.
|
||||
|
||||
The following distance types are available:
|
||||
|
||||
"l2" - Euclidean distance. This is a very common distance metric that
|
||||
accounts for both magnitude and direction when determining the distance
|
||||
between vectors. L2 distance has a range of [0, ∞).
|
||||
|
||||
"cosine" - Cosine distance. Cosine distance is a distance metric
|
||||
calculated from the cosine similarity between two vectors. Cosine
|
||||
similarity is a measure of similarity between two non-zero vectors of an
|
||||
inner product space. It is defined to equal the cosine of the angle
|
||||
between them. Unlike L2, the cosine distance is not affected by the
|
||||
magnitude of the vectors. Cosine distance has a range of [0, 2].
|
||||
|
||||
Note: the cosine distance is undefined when one (or both) of the vectors
|
||||
are all zeros (there is no direction). These vectors are invalid and may
|
||||
never be returned from a vector search.
|
||||
|
||||
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
|
||||
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
|
||||
L2 norm is 1), then dot distance is equivalent to the cosine distance.
|
||||
|
||||
"hamming" - Hamming distance. Hamming distance is a distance metric
|
||||
calculated as the number of positions at which the corresponding bits are
|
||||
different. Hamming distance has a range of [0, vector dimension].
|
||||
|
||||
num_partitions: int, default sqrt(num_rows)
|
||||
The number of IVF partitions to create.
|
||||
|
||||
This value should generally scale with the number of rows in the dataset.
|
||||
By default the number of partitions is the square root of the number of
|
||||
rows.
|
||||
|
||||
If this value is too large then the first part of the search (picking the
|
||||
right partition) will be slow. If this value is too small then the second
|
||||
part of the search (searching within a partition) will be slow.
|
||||
|
||||
max_iterations: int, default 50
|
||||
Max iteration to train kmeans.
|
||||
|
||||
When training an IVF PQ index we use kmeans to calculate the partitions.
|
||||
This parameter controls how many iterations of kmeans to run.
|
||||
|
||||
Increasing this might improve the quality of the index but in most cases
|
||||
these extra iterations have diminishing returns.
|
||||
|
||||
The default value is 50.
|
||||
sample_rate: int, default 256
|
||||
The rate used to calculate the number of training vectors for kmeans.
|
||||
|
||||
When an IVF PQ index is trained, we need to calculate partitions. These
|
||||
are groups of vectors that are similar to each other. To do this we use an
|
||||
algorithm called kmeans.
|
||||
|
||||
Running kmeans on a large dataset can be slow. To speed this up we run
|
||||
kmeans on a random sample of the data. This parameter controls the size of
|
||||
the sample. The total number of vectors used to train the index is
|
||||
`sample_rate * num_partitions`.
|
||||
|
||||
Increasing this value might improve the quality of the index but in most
|
||||
cases the default should be sufficient.
|
||||
|
||||
The default value is 256.
|
||||
"""
|
||||
|
||||
distance_type: Literal["l2", "cosine", "dot", "hamming"] = "l2"
|
||||
num_partitions: Optional[int] = None
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
|
||||
|
||||
@dataclass
|
||||
class IvfPq:
|
||||
"""Describes an IVF PQ Index
|
||||
|
||||
@@ -387,120 +469,113 @@ class IvfPq:
|
||||
|
||||
Note that training an IVF PQ index on a large dataset is a slow operation and
|
||||
currently is also a memory intensive operation.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
distance_type: str, default "L2"
|
||||
The distance metric used to train the index
|
||||
|
||||
This is used when training the index to calculate the IVF partitions
|
||||
(vectors are grouped in partitions with similar vectors according to this
|
||||
distance type) and to calculate a subvector's code during quantization.
|
||||
|
||||
The distance type used to train an index MUST match the distance type used
|
||||
to search the index. Failure to do so will yield inaccurate results.
|
||||
|
||||
The following distance types are available:
|
||||
|
||||
"l2" - Euclidean distance. This is a very common distance metric that
|
||||
accounts for both magnitude and direction when determining the distance
|
||||
between vectors. L2 distance has a range of [0, ∞).
|
||||
|
||||
"cosine" - Cosine distance. Cosine distance is a distance metric
|
||||
calculated from the cosine similarity between two vectors. Cosine
|
||||
similarity is a measure of similarity between two non-zero vectors of an
|
||||
inner product space. It is defined to equal the cosine of the angle
|
||||
between them. Unlike L2, the cosine distance is not affected by the
|
||||
magnitude of the vectors. Cosine distance has a range of [0, 2].
|
||||
|
||||
Note: the cosine distance is undefined when one (or both) of the vectors
|
||||
are all zeros (there is no direction). These vectors are invalid and may
|
||||
never be returned from a vector search.
|
||||
|
||||
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
|
||||
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
|
||||
L2 norm is 1), then dot distance is equivalent to the cosine distance.
|
||||
num_partitions: int, default sqrt(num_rows)
|
||||
The number of IVF partitions to create.
|
||||
|
||||
This value should generally scale with the number of rows in the dataset.
|
||||
By default the number of partitions is the square root of the number of
|
||||
rows.
|
||||
|
||||
If this value is too large then the first part of the search (picking the
|
||||
right partition) will be slow. If this value is too small then the second
|
||||
part of the search (searching within a partition) will be slow.
|
||||
num_sub_vectors: int, default is vector dimension / 16
|
||||
Number of sub-vectors of PQ.
|
||||
|
||||
This value controls how much the vector is compressed during the
|
||||
quantization step. The more sub vectors there are the less the vector is
|
||||
compressed. The default is the dimension of the vector divided by 16. If
|
||||
the dimension is not evenly divisible by 16 we use the dimension divded by
|
||||
8.
|
||||
|
||||
The above two cases are highly preferred. Having 8 or 16 values per
|
||||
subvector allows us to use efficient SIMD instructions.
|
||||
|
||||
If the dimension is not visible by 8 then we use 1 subvector. This is not
|
||||
ideal and will likely result in poor performance.
|
||||
num_bits: int, default 8
|
||||
Number of bits to encode each sub-vector.
|
||||
|
||||
This value controls how much the sub-vectors are compressed. The more bits
|
||||
the more accurate the index but the slower search. The default is 8
|
||||
bits. Only 4 and 8 are supported.
|
||||
max_iterations: int, default 50
|
||||
Max iteration to train kmeans.
|
||||
|
||||
When training an IVF PQ index we use kmeans to calculate the partitions.
|
||||
This parameter controls how many iterations of kmeans to run.
|
||||
|
||||
Increasing this might improve the quality of the index but in most cases
|
||||
these extra iterations have diminishing returns.
|
||||
|
||||
The default value is 50.
|
||||
sample_rate: int, default 256
|
||||
The rate used to calculate the number of training vectors for kmeans.
|
||||
|
||||
When an IVF PQ index is trained, we need to calculate partitions. These
|
||||
are groups of vectors that are similar to each other. To do this we use an
|
||||
algorithm called kmeans.
|
||||
|
||||
Running kmeans on a large dataset can be slow. To speed this up we run
|
||||
kmeans on a random sample of the data. This parameter controls the size of
|
||||
the sample. The total number of vectors used to train the index is
|
||||
`sample_rate * num_partitions`.
|
||||
|
||||
Increasing this value might improve the quality of the index but in most
|
||||
cases the default should be sufficient.
|
||||
|
||||
The default value is 256.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
num_bits: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Create an IVF PQ index config
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_type: str, default "L2"
|
||||
The distance metric used to train the index
|
||||
|
||||
This is used when training the index to calculate the IVF partitions
|
||||
(vectors are grouped in partitions with similar vectors according to this
|
||||
distance type) and to calculate a subvector's code during quantization.
|
||||
|
||||
The distance type used to train an index MUST match the distance type used
|
||||
to search the index. Failure to do so will yield inaccurate results.
|
||||
|
||||
The following distance types are available:
|
||||
|
||||
"l2" - Euclidean distance. This is a very common distance metric that
|
||||
accounts for both magnitude and direction when determining the distance
|
||||
between vectors. L2 distance has a range of [0, ∞).
|
||||
|
||||
"cosine" - Cosine distance. Cosine distance is a distance metric
|
||||
calculated from the cosine similarity between two vectors. Cosine
|
||||
similarity is a measure of similarity between two non-zero vectors of an
|
||||
inner product space. It is defined to equal the cosine of the angle
|
||||
between them. Unlike L2, the cosine distance is not affected by the
|
||||
magnitude of the vectors. Cosine distance has a range of [0, 2].
|
||||
|
||||
Note: the cosine distance is undefined when one (or both) of the vectors
|
||||
are all zeros (there is no direction). These vectors are invalid and may
|
||||
never be returned from a vector search.
|
||||
|
||||
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
|
||||
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
|
||||
L2 norm is 1), then dot distance is equivalent to the cosine distance.
|
||||
num_partitions: int, default sqrt(num_rows)
|
||||
The number of IVF partitions to create.
|
||||
|
||||
This value should generally scale with the number of rows in the dataset.
|
||||
By default the number of partitions is the square root of the number of
|
||||
rows.
|
||||
|
||||
If this value is too large then the first part of the search (picking the
|
||||
right partition) will be slow. If this value is too small then the second
|
||||
part of the search (searching within a partition) will be slow.
|
||||
num_sub_vectors: int, default is vector dimension / 16
|
||||
Number of sub-vectors of PQ.
|
||||
|
||||
This value controls how much the vector is compressed during the
|
||||
quantization step. The more sub vectors there are the less the vector is
|
||||
compressed. The default is the dimension of the vector divided by 16. If
|
||||
the dimension is not evenly divisible by 16 we use the dimension divded by
|
||||
8.
|
||||
|
||||
The above two cases are highly preferred. Having 8 or 16 values per
|
||||
subvector allows us to use efficient SIMD instructions.
|
||||
|
||||
If the dimension is not visible by 8 then we use 1 subvector. This is not
|
||||
ideal and will likely result in poor performance.
|
||||
num_bits: int, default 8
|
||||
Number of bits to encode each sub-vector.
|
||||
|
||||
This value controls how much the sub-vectors are compressed. The more bits
|
||||
the more accurate the index but the slower search. The default is 8
|
||||
bits. Only 4 and 8 are supported.
|
||||
max_iterations: int, default 50
|
||||
Max iteration to train kmeans.
|
||||
|
||||
When training an IVF PQ index we use kmeans to calculate the partitions.
|
||||
This parameter controls how many iterations of kmeans to run.
|
||||
|
||||
Increasing this might improve the quality of the index but in most cases
|
||||
these extra iterations have diminishing returns.
|
||||
|
||||
The default value is 50.
|
||||
sample_rate: int, default 256
|
||||
The rate used to calculate the number of training vectors for kmeans.
|
||||
|
||||
When an IVF PQ index is trained, we need to calculate partitions. These
|
||||
are groups of vectors that are similar to each other. To do this we use an
|
||||
algorithm called kmeans.
|
||||
|
||||
Running kmeans on a large dataset can be slow. To speed this up we run
|
||||
kmeans on a random sample of the data. This parameter controls the size of
|
||||
the sample. The total number of vectors used to train the index is
|
||||
`sample_rate * num_partitions`.
|
||||
|
||||
Increasing this value might improve the quality of the index but in most
|
||||
cases the default should be sufficient.
|
||||
|
||||
The default value is 256.
|
||||
"""
|
||||
if distance_type is not None:
|
||||
distance_type = distance_type.lower()
|
||||
self._inner = LanceDbIndex.ivf_pq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
distance_type: Literal["l2", "cosine", "dot"] = "l2"
|
||||
num_partitions: Optional[int] = None
|
||||
num_sub_vectors: Optional[int] = None
|
||||
num_bits: int = 8
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
|
||||
|
||||
__all__ = ["BTree", "IvfPq", "IndexConfig"]
|
||||
__all__ = [
|
||||
"BTree",
|
||||
"IvfPq",
|
||||
"IvfFlat",
|
||||
"HnswPq",
|
||||
"HnswSq",
|
||||
"IndexConfig",
|
||||
"FTS",
|
||||
"Bitmap",
|
||||
"LabelList",
|
||||
]
|
||||
|
||||
@@ -126,6 +126,9 @@ class Query(pydantic.BaseModel):
|
||||
|
||||
ef: Optional[int] = None
|
||||
|
||||
# Default is true. Set to false to enforce a brute force search.
|
||||
use_index: bool = True
|
||||
|
||||
|
||||
class LanceQueryBuilder(ABC):
|
||||
"""An abstract query builder. Subclasses are defined for vector search,
|
||||
@@ -253,6 +256,7 @@ class LanceQueryBuilder(ABC):
|
||||
self._vector = None
|
||||
self._text = None
|
||||
self._ef = None
|
||||
self._use_index = True
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -511,6 +515,7 @@ class LanceQueryBuilder(ABC):
|
||||
"metric": self._metric,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
"use_index": self._use_index,
|
||||
},
|
||||
prefilter=self._prefilter,
|
||||
filter=self._str_query,
|
||||
@@ -729,6 +734,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
ef=self._ef,
|
||||
use_index=self._use_index,
|
||||
)
|
||||
result_set = self._table._execute_query(query, batch_size)
|
||||
if self._reranker is not None:
|
||||
@@ -802,6 +808,24 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._str_query = query_string if query_string is not None else self._str_query
|
||||
return self
|
||||
|
||||
def bypass_vector_index(self) -> LanceVectorQueryBuilder:
|
||||
"""
|
||||
If this is called then any vector index is skipped
|
||||
|
||||
An exhaustive (flat) search will be performed. The query vector will
|
||||
be compared to every vector in the table. At high scales this can be
|
||||
expensive. However, this is often still useful. For example, skipping
|
||||
the vector index can give you ground truth results which you can use to
|
||||
calculate your recall to select an appropriate value for nprobes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceVectorQueryBuilder object.
|
||||
"""
|
||||
self._use_index = False
|
||||
return self
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
"""A builder for full text search for LanceDB."""
|
||||
@@ -1108,6 +1132,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._vector_query.refine_factor(self._refine_factor)
|
||||
if self._ef:
|
||||
self._vector_query.ef(self._ef)
|
||||
if not self._use_index:
|
||||
self._vector_query.bypass_vector_index()
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
@@ -1323,6 +1349,24 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._text = text
|
||||
return self
|
||||
|
||||
def bypass_vector_index(self) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
If this is called then any vector index is skipped
|
||||
|
||||
An exhaustive (flat) search will be performed. The query vector will
|
||||
be compared to every vector in the table. At high scales this can be
|
||||
expensive. However, this is often still useful. For example, skipping
|
||||
the vector index can give you ground truth results which you can use to
|
||||
calculate your recall to select an appropriate value for nprobes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._use_index = False
|
||||
return self
|
||||
|
||||
|
||||
class AsyncQueryBase(object):
|
||||
def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):
|
||||
|
||||
@@ -121,7 +121,13 @@ class RemoteDBConnection(DBConnection):
|
||||
return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit))
|
||||
|
||||
@override
|
||||
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
|
||||
def open_table(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -1,22 +1,15 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||
import warnings
|
||||
|
||||
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
|
||||
from lancedb._lancedb import IndexConfig
|
||||
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
||||
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
|
||||
from lancedb.remote.db import LOOP
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -25,7 +18,7 @@ from lancedb.merge import LanceMergeInsertBuilder
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
|
||||
from ..table import AsyncTable, Query, Table
|
||||
from ..table import AsyncTable, IndexStatistics, Query, Table
|
||||
|
||||
|
||||
class RemoteTable(Table):
|
||||
@@ -62,7 +55,7 @@ class RemoteTable(Table):
|
||||
return LOOP.run(self._table.version())
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
|
||||
"""
|
||||
Get the embedding functions for the table
|
||||
|
||||
@@ -88,17 +81,17 @@ class RemoteTable(Table):
|
||||
"""to_pandas() is not yet supported on LanceDB cloud."""
|
||||
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
||||
|
||||
def checkout(self, version):
|
||||
def checkout(self, version: int):
|
||||
return LOOP.run(self._table.checkout(version))
|
||||
|
||||
def checkout_latest(self):
|
||||
return LOOP.run(self._table.checkout_latest())
|
||||
|
||||
def list_indices(self):
|
||||
def list_indices(self) -> Iterable[IndexConfig]:
|
||||
"""List all the indices on the table"""
|
||||
return LOOP.run(self._table.list_indices())
|
||||
|
||||
def index_stats(self, index_uuid: str):
|
||||
def index_stats(self, index_uuid: str) -> Optional[IndexStatistics]:
|
||||
"""List all the stats of a specified index"""
|
||||
return LOOP.run(self._table.index_stats(index_uuid))
|
||||
|
||||
@@ -232,10 +225,12 @@ class RemoteTable(Table):
|
||||
config = HnswPq(distance_type=metric)
|
||||
elif index_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(distance_type=metric)
|
||||
elif index_type == "IVF_FLAT":
|
||||
config = IvfFlat(distance_type=metric)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {index_type}. Valid options are"
|
||||
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||
)
|
||||
|
||||
LOOP.run(self._table.create_index(vector_column_name, config=config))
|
||||
@@ -479,16 +474,28 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
def cleanup_old_versions(self, *_):
|
||||
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||
raise NotImplementedError(
|
||||
"cleanup_old_versions() is not supported on the LanceDB cloud"
|
||||
"""
|
||||
cleanup_old_versions() is a no-op on LanceDB Cloud.
|
||||
|
||||
Tables are automatically cleaned up and optimized.
|
||||
"""
|
||||
warnings.warn(
|
||||
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
|
||||
"Tables are automatically cleaned up and optimized."
|
||||
)
|
||||
pass
|
||||
|
||||
def compact_files(self, *_):
|
||||
"""compact_files() is not supported on the LanceDB cloud"""
|
||||
raise NotImplementedError(
|
||||
"compact_files() is not supported on the LanceDB cloud"
|
||||
"""
|
||||
compact_files() is a no-op on LanceDB Cloud.
|
||||
|
||||
Tables are automatically compacted and optimized.
|
||||
"""
|
||||
warnings.warn(
|
||||
"compact_files() is a no-op on LanceDB Cloud. "
|
||||
"Tables are automatically compacted and optimized."
|
||||
)
|
||||
pass
|
||||
|
||||
def optimize(
|
||||
self,
|
||||
@@ -496,12 +503,16 @@ class RemoteTable(Table):
|
||||
cleanup_older_than: Optional[timedelta] = None,
|
||||
delete_unverified: bool = False,
|
||||
):
|
||||
"""optimize() is not supported on the LanceDB cloud.
|
||||
Indices are optimized automatically."""
|
||||
raise NotImplementedError(
|
||||
"optimize() is not supported on the LanceDB cloud. "
|
||||
"""
|
||||
optimize() is a no-op on LanceDB Cloud.
|
||||
|
||||
Indices are optimized automatically.
|
||||
"""
|
||||
warnings.warn(
|
||||
"optimize() is a no-op on LanceDB Cloud. "
|
||||
"Indices are optimized automatically."
|
||||
)
|
||||
pass
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
return LOOP.run(self._table.count_rows(filter))
|
||||
@@ -515,6 +526,16 @@ class RemoteTable(Table):
|
||||
def drop_columns(self, columns: Iterable[str]):
|
||||
return LOOP.run(self._table.drop_columns(columns))
|
||||
|
||||
def uses_v2_manifest_paths(self) -> bool:
|
||||
raise NotImplementedError(
|
||||
"uses_v2_manifest_paths() is not supported on the LanceDB Cloud"
|
||||
)
|
||||
|
||||
def migrate_v2_manifest_paths(self):
|
||||
raise NotImplementedError(
|
||||
"migrate_v2_manifest_paths() is not supported on the LanceDB Cloud"
|
||||
)
|
||||
|
||||
|
||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||
return tbl.add_column(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -314,3 +314,15 @@ def deprecated(func):
|
||||
def validate_table_name(name: str):
|
||||
"""Verify the table name is valid."""
|
||||
native_validate_table_name(name)
|
||||
|
||||
|
||||
def add_note(base_exception: BaseException, note: str):
|
||||
if hasattr(base_exception, "add_note"):
|
||||
base_exception.add_note(note)
|
||||
elif isinstance(base_exception.args[0], str):
|
||||
base_exception.args = (
|
||||
base_exception.args[0] + "\n" + note,
|
||||
*base_exception.args[1:],
|
||||
)
|
||||
else:
|
||||
raise ValueError("Cannot add note to exception")
|
||||
|
||||
32
python/python/tests/conftest.py
Normal file
32
python/python/tests/conftest.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from datetime import timedelta
|
||||
from lancedb.db import AsyncConnection, DBConnection
|
||||
import lancedb
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
# Use an in-memory database for most tests.
|
||||
@pytest.fixture
|
||||
def mem_db() -> DBConnection:
|
||||
return lancedb.connect("memory://")
|
||||
|
||||
|
||||
# Use a temporary directory when we need to inspect the database files.
|
||||
@pytest.fixture
|
||||
def tmp_db(tmp_path) -> DBConnection:
|
||||
return lancedb.connect(tmp_path)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mem_db_async() -> AsyncConnection:
|
||||
return await lancedb.connect_async("memory://")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def tmp_db_async(tmp_path) -> AsyncConnection:
|
||||
return await lancedb.connect_async(
|
||||
tmp_path, read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
@@ -75,6 +75,22 @@ def test_quickstart():
|
||||
for _ in range(1000)
|
||||
]
|
||||
)
|
||||
# --8<-- [start:add_columns]
|
||||
tbl.add_columns({"double_price": "cast((price * 2) as float)"})
|
||||
# --8<-- [end:add_columns]
|
||||
# --8<-- [start:alter_columns]
|
||||
tbl.alter_columns(
|
||||
{
|
||||
"path": "double_price",
|
||||
"rename": "dbl_price",
|
||||
"data_type": pa.float64(),
|
||||
"nullable": True,
|
||||
}
|
||||
)
|
||||
# --8<-- [end:alter_columns]
|
||||
# --8<-- [start:drop_columns]
|
||||
tbl.drop_columns(["dbl_price"])
|
||||
# --8<-- [end:drop_columns]
|
||||
# --8<-- [start:create_index]
|
||||
# Synchronous client
|
||||
tbl.create_index(num_sub_vectors=1)
|
||||
|
||||
44
python/python/tests/docs/test_binary_vector.py
Normal file
44
python/python/tests/docs/test_binary_vector.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import shutil
|
||||
|
||||
# --8<-- [start:imports]
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pytest
|
||||
# --8<-- [end:imports]
|
||||
|
||||
shutil.rmtree("data/binary_lancedb", ignore_errors=True)
|
||||
|
||||
|
||||
def test_binary_vector():
|
||||
# --8<-- [start:sync_binary_vector]
|
||||
db = lancedb.connect("data/binary_lancedb")
|
||||
data = [
|
||||
{
|
||||
"id": i,
|
||||
"vector": np.random.randint(0, 256, size=16),
|
||||
}
|
||||
for i in range(1024)
|
||||
]
|
||||
tbl = db.create_table("my_binary_vectors", data=data)
|
||||
query = np.random.randint(0, 256, size=16)
|
||||
tbl.search(query).to_arrow()
|
||||
# --8<-- [end:sync_binary_vector]
|
||||
db.drop_table("my_binary_vectors")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_binary_vector_async():
|
||||
# --8<-- [start:async_binary_vector]
|
||||
db = await lancedb.connect_async("data/binary_lancedb")
|
||||
data = [
|
||||
{
|
||||
"id": i,
|
||||
"vector": np.random.randint(0, 256, size=16),
|
||||
}
|
||||
for i in range(1024)
|
||||
]
|
||||
tbl = await db.create_table("my_binary_vectors", data=data)
|
||||
query = np.random.randint(0, 256, size=16)
|
||||
await tbl.query().nearest_to(query).to_arrow()
|
||||
# --8<-- [end:async_binary_vector]
|
||||
await db.drop_table("my_binary_vectors")
|
||||
@@ -98,7 +98,7 @@ def test_ingest_pd(tmp_path):
|
||||
assert db.open_table("test").name == db["test"].name
|
||||
|
||||
|
||||
def test_ingest_iterator(tmp_path):
|
||||
def test_ingest_iterator(mem_db: lancedb.DBConnection):
|
||||
class PydanticSchema(LanceModel):
|
||||
vector: Vector(2)
|
||||
item: str
|
||||
@@ -156,8 +156,7 @@ def test_ingest_iterator(tmp_path):
|
||||
]
|
||||
|
||||
def run_tests(schema):
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
||||
tbl = mem_db.create_table("table2", make_batches(), schema=schema)
|
||||
tbl.to_pandas()
|
||||
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
|
||||
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
|
||||
@@ -165,15 +164,14 @@ def test_ingest_iterator(tmp_path):
|
||||
tbl.add(make_batches())
|
||||
assert tbl_len == 50
|
||||
assert len(tbl) == tbl_len * 2
|
||||
assert len(tbl.list_versions()) == 3
|
||||
db.drop_database()
|
||||
assert len(tbl.list_versions()) == 2
|
||||
mem_db.drop_database()
|
||||
|
||||
run_tests(arrow_schema)
|
||||
run_tests(PydanticSchema)
|
||||
|
||||
|
||||
def test_table_names(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
def test_table_names(tmp_db: lancedb.DBConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -181,10 +179,10 @@ def test_table_names(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
db.create_table("test2", data=data)
|
||||
db.create_table("test1", data=data)
|
||||
db.create_table("test3", data=data)
|
||||
assert db.table_names() == ["test1", "test2", "test3"]
|
||||
tmp_db.create_table("test2", data=data)
|
||||
tmp_db.create_table("test1", data=data)
|
||||
tmp_db.create_table("test3", data=data)
|
||||
assert tmp_db.table_names() == ["test1", "test2", "test3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -209,8 +207,7 @@ async def test_table_names_async(tmp_path):
|
||||
assert await db.table_names(start_after="test1") == ["test2", "test3"]
|
||||
|
||||
|
||||
def test_create_mode(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
def test_create_mode(tmp_db: lancedb.DBConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -218,10 +215,10 @@ def test_create_mode(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
db.create_table("test", data=data)
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
db.create_table("test", data=data)
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
new_data = pd.DataFrame(
|
||||
{
|
||||
@@ -230,13 +227,11 @@ def test_create_mode(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
tbl = db.create_table("test", data=new_data, mode="overwrite")
|
||||
tbl = tmp_db.create_table("test", data=new_data, mode="overwrite")
|
||||
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
|
||||
|
||||
|
||||
def test_create_table_from_iterator(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
def test_create_table_from_iterator(mem_db: lancedb.DBConnection):
|
||||
def gen_data():
|
||||
for _ in range(10):
|
||||
yield pa.RecordBatch.from_arrays(
|
||||
@@ -248,14 +243,12 @@ def test_create_table_from_iterator(tmp_path):
|
||||
["vector", "item", "price"],
|
||||
)
|
||||
|
||||
table = db.create_table("test", data=gen_data())
|
||||
table = mem_db.create_table("test", data=gen_data())
|
||||
assert table.count_rows() == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_table_from_iterator_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
|
||||
async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConnection):
|
||||
def gen_data():
|
||||
for _ in range(10):
|
||||
yield pa.RecordBatch.from_arrays(
|
||||
@@ -267,12 +260,11 @@ async def test_create_table_from_iterator_async(tmp_path):
|
||||
["vector", "item", "price"],
|
||||
)
|
||||
|
||||
table = await db.create_table("test", data=gen_data())
|
||||
table = await mem_db_async.create_table("test", data=gen_data())
|
||||
assert await table.count_rows() == 10
|
||||
|
||||
|
||||
def test_create_exist_ok(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -280,13 +272,13 @@ def test_create_exist_ok(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
tbl = db.create_table("test", data=data)
|
||||
tbl = tmp_db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(OSError):
|
||||
db.create_table("test", data=data)
|
||||
with pytest.raises(ValueError):
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
# open the table but don't add more rows
|
||||
tbl2 = db.create_table("test", data=data, exist_ok=True)
|
||||
tbl2 = tmp_db.create_table("test", data=data, exist_ok=True)
|
||||
assert tbl.name == tbl2.name
|
||||
assert tbl.schema == tbl2.schema
|
||||
assert len(tbl) == len(tbl2)
|
||||
@@ -298,7 +290,7 @@ def test_create_exist_ok(tmp_path):
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
tbl3 = db.create_table("test", schema=schema, exist_ok=True)
|
||||
tbl3 = tmp_db.create_table("test", schema=schema, exist_ok=True)
|
||||
assert tbl3.schema == schema
|
||||
|
||||
bad_schema = pa.schema(
|
||||
@@ -310,7 +302,7 @@ def test_create_exist_ok(tmp_path):
|
||||
]
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||
tmp_db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -325,26 +317,24 @@ async def test_connect(tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
assert db.is_open()
|
||||
db.close()
|
||||
assert not db.is_open()
|
||||
async def test_close(mem_db_async: lancedb.AsyncConnection):
|
||||
assert mem_db_async.is_open()
|
||||
mem_db_async.close()
|
||||
assert not mem_db_async.is_open()
|
||||
|
||||
with pytest.raises(RuntimeError, match="is closed"):
|
||||
await db.table_names()
|
||||
await mem_db_async.table_names()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_manager(tmp_path):
|
||||
with await lancedb.connect_async(tmp_path) as db:
|
||||
async def test_context_manager():
|
||||
with await lancedb.connect_async("memory://") as db:
|
||||
assert db.is_open()
|
||||
assert not db.is_open()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mode_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -352,10 +342,10 @@ async def test_create_mode_async(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
await db.create_table("test", data=data)
|
||||
await tmp_db_async.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await db.create_table("test", data=data)
|
||||
await tmp_db_async.create_table("test", data=data)
|
||||
|
||||
new_data = pd.DataFrame(
|
||||
{
|
||||
@@ -364,15 +354,14 @@ async def test_create_mode_async(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
_tbl = await db.create_table("test", data=new_data, mode="overwrite")
|
||||
_tbl = await tmp_db_async.create_table("test", data=new_data, mode="overwrite")
|
||||
|
||||
# MIGRATION: to_pandas() is not available in async
|
||||
# assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_exist_ok_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -380,13 +369,13 @@ async def test_create_exist_ok_async(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
tbl = await db.create_table("test", data=data)
|
||||
tbl = await tmp_db_async.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await db.create_table("test", data=data)
|
||||
await tmp_db_async.create_table("test", data=data)
|
||||
|
||||
# open the table but don't add more rows
|
||||
tbl2 = await db.create_table("test", data=data, exist_ok=True)
|
||||
tbl2 = await tmp_db_async.create_table("test", data=data, exist_ok=True)
|
||||
assert tbl.name == tbl2.name
|
||||
assert await tbl.schema() == await tbl2.schema()
|
||||
|
||||
@@ -397,7 +386,7 @@ async def test_create_exist_ok_async(tmp_path):
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
|
||||
tbl3 = await tmp_db_async.create_table("test", schema=schema, exist_ok=True)
|
||||
assert await tbl3.schema() == schema
|
||||
|
||||
# Migration: When creating a table, but the table already exists, but
|
||||
@@ -448,13 +437,12 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
|
||||
|
||||
def test_open_table_sync(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
db.create_table("test", data=[{"id": 0}])
|
||||
assert db.open_table("test").count_rows() == 1
|
||||
assert db.open_table("test", index_cache_size=0).count_rows() == 1
|
||||
with pytest.raises(FileNotFoundError, match="does not exist"):
|
||||
db.open_table("does_not_exist")
|
||||
def test_open_table_sync(tmp_db: lancedb.DBConnection):
|
||||
tmp_db.create_table("test", data=[{"id": 0}])
|
||||
assert tmp_db.open_table("test").count_rows() == 1
|
||||
assert tmp_db.open_table("test", index_cache_size=0).count_rows() == 1
|
||||
with pytest.raises(ValueError, match="Table 'does_not_exist' was not found"):
|
||||
tmp_db.open_table("does_not_exist")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -494,8 +482,7 @@ async def test_open_table(tmp_path):
|
||||
await db.open_table("does_not_exist")
|
||||
|
||||
|
||||
def test_delete_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
def test_delete_table(tmp_db: lancedb.DBConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -503,26 +490,51 @@ def test_delete_table(tmp_path):
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
db.create_table("test", data=data)
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
db.create_table("test", data=data)
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
assert db.table_names() == ["test"]
|
||||
assert tmp_db.table_names() == ["test"]
|
||||
|
||||
db.drop_table("test")
|
||||
assert db.table_names() == []
|
||||
tmp_db.drop_table("test")
|
||||
assert tmp_db.table_names() == []
|
||||
|
||||
db.create_table("test", data=data)
|
||||
assert db.table_names() == ["test"]
|
||||
tmp_db.create_table("test", data=data)
|
||||
assert tmp_db.table_names() == ["test"]
|
||||
|
||||
# dropping a table that does not exist should pass
|
||||
# if ignore_missing=True
|
||||
db.drop_table("does_not_exist", ignore_missing=True)
|
||||
tmp_db.drop_table("does_not_exist", ignore_missing=True)
|
||||
|
||||
|
||||
def test_drop_database(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_table_async(tmp_db: lancedb.DBConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
assert tmp_db.table_names() == ["test"]
|
||||
|
||||
tmp_db.drop_table("test")
|
||||
assert tmp_db.table_names() == []
|
||||
|
||||
tmp_db.create_table("test", data=data)
|
||||
assert tmp_db.table_names() == ["test"]
|
||||
|
||||
tmp_db.drop_table("does_not_exist", ignore_missing=True)
|
||||
|
||||
|
||||
def test_drop_database(tmp_db: lancedb.DBConnection):
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -537,51 +549,50 @@ def test_drop_database(tmp_path):
|
||||
"price": [12.0, 17.0],
|
||||
}
|
||||
)
|
||||
db.create_table("test", data=data)
|
||||
tmp_db.create_table("test", data=data)
|
||||
with pytest.raises(Exception):
|
||||
db.create_table("test", data=data)
|
||||
tmp_db.create_table("test", data=data)
|
||||
|
||||
assert db.table_names() == ["test"]
|
||||
assert tmp_db.table_names() == ["test"]
|
||||
|
||||
db.create_table("new_test", data=new_data)
|
||||
db.drop_database()
|
||||
assert db.table_names() == []
|
||||
tmp_db.create_table("new_test", data=new_data)
|
||||
tmp_db.drop_database()
|
||||
assert tmp_db.table_names() == []
|
||||
|
||||
# it should pass when no tables are present
|
||||
db.create_table("test", data=new_data)
|
||||
db.drop_table("test")
|
||||
assert db.table_names() == []
|
||||
db.drop_database()
|
||||
assert db.table_names() == []
|
||||
tmp_db.create_table("test", data=new_data)
|
||||
tmp_db.drop_table("test")
|
||||
assert tmp_db.table_names() == []
|
||||
tmp_db.drop_database()
|
||||
assert tmp_db.table_names() == []
|
||||
|
||||
# creating an empty database with schema
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
|
||||
db.create_table("empty_table", schema=schema)
|
||||
tmp_db.create_table("empty_table", schema=schema)
|
||||
# dropping a empty database should pass
|
||||
db.drop_database()
|
||||
assert db.table_names() == []
|
||||
tmp_db.drop_database()
|
||||
assert tmp_db.table_names() == []
|
||||
|
||||
|
||||
def test_empty_or_nonexistent_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection):
|
||||
with pytest.raises(Exception):
|
||||
db.create_table("test_with_no_data")
|
||||
mem_db.create_table("test_with_no_data")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
db.open_table("does_not_exist")
|
||||
mem_db.open_table("does_not_exist")
|
||||
|
||||
schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
|
||||
test = db.create_table("test", schema=schema)
|
||||
test = mem_db.create_table("test", schema=schema)
|
||||
|
||||
class TestModel(LanceModel):
|
||||
a: int
|
||||
|
||||
test2 = db.create_table("test2", schema=TestModel)
|
||||
test2 = mem_db.create_table("test2", schema=TestModel)
|
||||
assert test.schema == test2.schema
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_in_v2_mode(tmp_path):
|
||||
async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
|
||||
def make_data():
|
||||
for i in range(10):
|
||||
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
|
||||
@@ -591,10 +602,8 @@ async def test_create_in_v2_mode(tmp_path):
|
||||
|
||||
schema = pa.schema([pa.field("x", pa.int64())])
|
||||
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
|
||||
# Create table in v1 mode
|
||||
tbl = await db.create_table(
|
||||
tbl = await mem_db_async.create_table(
|
||||
"test", data=make_data(), schema=schema, data_storage_version="legacy"
|
||||
)
|
||||
|
||||
@@ -610,7 +619,7 @@ async def test_create_in_v2_mode(tmp_path):
|
||||
assert not await is_in_v2_mode(tbl)
|
||||
|
||||
# Create table in v2 mode
|
||||
tbl = await db.create_table(
|
||||
tbl = await mem_db_async.create_table(
|
||||
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
|
||||
)
|
||||
|
||||
@@ -622,7 +631,7 @@ async def test_create_in_v2_mode(tmp_path):
|
||||
assert await is_in_v2_mode(tbl)
|
||||
|
||||
# Create empty table in v2 mode and add data
|
||||
tbl = await db.create_table(
|
||||
tbl = await mem_db_async.create_table(
|
||||
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
|
||||
)
|
||||
await tbl.add(make_table())
|
||||
@@ -630,7 +639,7 @@ async def test_create_in_v2_mode(tmp_path):
|
||||
assert await is_in_v2_mode(tbl)
|
||||
|
||||
# Create empty table uses v1 mode by default
|
||||
tbl = await db.create_table(
|
||||
tbl = await mem_db_async.create_table(
|
||||
"test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
|
||||
)
|
||||
await tbl.add(make_table())
|
||||
@@ -638,18 +647,17 @@ async def test_create_in_v2_mode(tmp_path):
|
||||
assert not await is_in_v2_mode(tbl)
|
||||
|
||||
|
||||
def test_replace_index(tmp_path):
|
||||
db = lancedb.connect(uri=tmp_path)
|
||||
table = db.create_table(
|
||||
def test_replace_index(mem_db: lancedb.DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
[
|
||||
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
|
||||
for i in range(1000)
|
||||
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
|
||||
for i in range(512)
|
||||
],
|
||||
)
|
||||
table.create_index(
|
||||
num_partitions=2,
|
||||
num_sub_vectors=4,
|
||||
num_sub_vectors=2,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
@@ -660,27 +668,26 @@ def test_replace_index(tmp_path):
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
num_partitions=2,
|
||||
num_sub_vectors=4,
|
||||
num_partitions=1,
|
||||
num_sub_vectors=2,
|
||||
replace=True,
|
||||
index_cache_size=10,
|
||||
)
|
||||
|
||||
|
||||
def test_prefilter_with_index(tmp_path):
|
||||
db = lancedb.connect(uri=tmp_path)
|
||||
def test_prefilter_with_index(mem_db: lancedb.DBConnection):
|
||||
data = [
|
||||
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
|
||||
for i in range(1000)
|
||||
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
|
||||
for i in range(512)
|
||||
]
|
||||
sample_key = data[100]["vector"]
|
||||
table = db.create_table(
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data,
|
||||
)
|
||||
table.create_index(
|
||||
num_partitions=2,
|
||||
num_sub_vectors=4,
|
||||
num_sub_vectors=2,
|
||||
)
|
||||
table = (
|
||||
table.search(sample_key)
|
||||
@@ -691,13 +698,34 @@ def test_prefilter_with_index(tmp_path):
|
||||
assert table.num_rows == 1
|
||||
|
||||
|
||||
def test_create_table_with_invalid_names(tmp_path):
|
||||
db = lancedb.connect(uri=tmp_path)
|
||||
def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection):
|
||||
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo/bar", data)
|
||||
tmp_db.create_table("foo/bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo bar", data)
|
||||
tmp_db.create_table("foo bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo$$bar", data)
|
||||
db.create_table("foo.bar", data)
|
||||
tmp_db.create_table("foo$$bar", data)
|
||||
tmp_db.create_table("foo.bar", data)
|
||||
|
||||
|
||||
def test_bypass_vector_index_sync(tmp_db: lancedb.DBConnection):
|
||||
data = [{"vector": np.random.rand(32)} for _ in range(512)]
|
||||
sample_key = data[100]["vector"]
|
||||
table = tmp_db.create_table(
|
||||
"test",
|
||||
data,
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
num_partitions=2,
|
||||
num_sub_vectors=2,
|
||||
)
|
||||
|
||||
plan_with_index = table.search(sample_key).explain_plan(verbose=True)
|
||||
assert "ANN" in plan_with_index
|
||||
|
||||
plan_without_index = (
|
||||
table.search(sample_key).bypass_vector_index().explain_plan(verbose=True)
|
||||
)
|
||||
assert "KNN" in plan_without_index
|
||||
|
||||
@@ -15,10 +15,12 @@ import random
|
||||
from unittest import mock
|
||||
|
||||
import lancedb as ldb
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.index import FTS
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from utils import exception_output
|
||||
|
||||
pytest.importorskip("lancedb.fts")
|
||||
tantivy = pytest.importorskip("tantivy")
|
||||
@@ -458,3 +460,44 @@ def test_syntax(table):
|
||||
table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit(
|
||||
10
|
||||
).to_list()
|
||||
|
||||
|
||||
def test_language(mem_db: DBConnection):
|
||||
sentences = [
|
||||
"Il n'y a que trois routes qui traversent la ville.",
|
||||
"Je veux prendre la route vers l'est.",
|
||||
"Je te retrouve au café au bout de la route.",
|
||||
]
|
||||
data = [{"text": s} for s in sentences]
|
||||
table = mem_db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
table.create_fts_index("text", use_tantivy=False, language="klingon")
|
||||
|
||||
assert exception_output(e) == (
|
||||
"ValueError: LanceDB does not support the requested language: 'klingon'\n"
|
||||
"Supported languages: Arabic, Danish, Dutch, English, Finnish, French, "
|
||||
"German, Greek, Hungarian, Italian, Norwegian, Portuguese, Romanian, "
|
||||
"Russian, Spanish, Swedish, Tamil, Turkish"
|
||||
)
|
||||
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
use_tantivy=False,
|
||||
language="French",
|
||||
stem=True,
|
||||
ascii_folding=True,
|
||||
remove_stop_words=True,
|
||||
)
|
||||
|
||||
# Can get "routes" and "route" from the same root
|
||||
results = table.search("route", query_type="fts").limit(5).to_list()
|
||||
assert len(results) == 3
|
||||
|
||||
# Can find "café", without needing to provide accent
|
||||
results = table.search("cafe", query_type="fts").limit(5).to_list()
|
||||
assert len(results) == 1
|
||||
|
||||
# Stop words -> no results
|
||||
results = table.search("la", query_type="fts").limit(5).to_list()
|
||||
assert len(results) == 0
|
||||
|
||||
@@ -8,7 +8,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb import AsyncConnection, AsyncTable, connect_async
|
||||
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
|
||||
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -42,6 +42,27 @@ async def some_table(db_async):
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def binary_table(db_async):
|
||||
data = [
|
||||
{
|
||||
"id": i,
|
||||
"vector": [i] * 128,
|
||||
}
|
||||
for i in range(NROWS)
|
||||
]
|
||||
return await db_async.create_table(
|
||||
"binary_table",
|
||||
data,
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field("vector", pa.list_(pa.uint8(), 128)),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_scalar_index(some_table: AsyncTable):
|
||||
# Can create
|
||||
@@ -143,3 +164,27 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
|
||||
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
|
||||
await binary_table.create_index(
|
||||
"vector", config=IvfFlat(distance_type="hamming", num_partitions=10)
|
||||
)
|
||||
indices = await binary_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "IvfFlat"
|
||||
assert indices[0].columns == ["vector"]
|
||||
assert indices[0].name == "vector_idx"
|
||||
|
||||
stats = await binary_table.index_stats("vector_idx")
|
||||
assert stats.index_type == "IVF_FLAT"
|
||||
assert stats.distance_type == "hamming"
|
||||
assert stats.num_indexed_rows == await binary_table.count_rows()
|
||||
assert stats.num_unindexed_rows == 0
|
||||
assert stats.num_indices == 1
|
||||
|
||||
# the dataset contains vectors with all values from 0 to 255
|
||||
for v in range(256):
|
||||
res = await binary_table.query().nearest_to([v] * 128).to_arrow()
|
||||
assert res["id"][0].as_py() == v
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
11
python/python/tests/utils.py
Normal file
11
python/python/tests/utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import pytest
|
||||
|
||||
|
||||
def exception_output(e_info: pytest.ExceptionInfo):
|
||||
import traceback
|
||||
|
||||
# skip traceback part, since it's not worth checking in tests
|
||||
lines = traceback.format_exception_only(e_info.type, e_info.value)
|
||||
return "".join(lines).strip()
|
||||
@@ -58,6 +58,11 @@ impl Connection {
|
||||
self.inner.take();
|
||||
}
|
||||
|
||||
#[getter]
|
||||
pub fn uri(&self) -> PyResult<String> {
|
||||
self.get_inner().map(|inner| inner.uri().to_string())
|
||||
}
|
||||
|
||||
#[pyo3(signature = (start_after=None, limit=None))]
|
||||
pub fn table_names(
|
||||
self_: PyRef<'_, Self>,
|
||||
|
||||
@@ -12,224 +12,174 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use lancedb::index::scalar::FtsIndexBuilder;
|
||||
use lancedb::{
|
||||
index::{
|
||||
scalar::BTreeIndexBuilder,
|
||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||
Index as LanceDbIndex,
|
||||
},
|
||||
DistanceType,
|
||||
use lancedb::index::vector::IvfFlatIndexBuilder;
|
||||
use lancedb::index::{
|
||||
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
|
||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||
Index as LanceDbIndex,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
|
||||
exceptions::{PyKeyError, PyValueError},
|
||||
intern, pyclass, pymethods,
|
||||
types::PyAnyMethods,
|
||||
Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python,
|
||||
};
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Index {
|
||||
inner: Mutex<Option<LanceDbIndex>>,
|
||||
}
|
||||
|
||||
impl Index {
|
||||
pub fn consume(&self) -> PyResult<LanceDbIndex> {
|
||||
self.inner
|
||||
.lock()
|
||||
.unwrap()
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
|
||||
pub fn class_name<'a>(ob: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
|
||||
let full_name: &str = ob
|
||||
.getattr(intern!(ob.py(), "__class__"))?
|
||||
.getattr(intern!(ob.py(), "__name__"))?
|
||||
.extract()?;
|
||||
match full_name.rsplit_once('.') {
|
||||
Some((_, name)) => Ok(name),
|
||||
None => Ok(full_name),
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Index {
|
||||
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None))]
|
||||
#[staticmethod]
|
||||
pub fn ivf_pq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
num_bits: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
) -> PyResult<Self> {
|
||||
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = match distance_type.as_str() {
|
||||
"l2" => Ok(DistanceType::L2),
|
||||
"cosine" => Ok(DistanceType::Cosine),
|
||||
"dot" => Ok(DistanceType::Dot),
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||
distance_type
|
||||
))),
|
||||
}?;
|
||||
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
|
||||
pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<LanceDbIndex> {
|
||||
if let Some(source) = source {
|
||||
match class_name(source)? {
|
||||
"BTree" => Ok(LanceDbIndex::BTree(BTreeIndexBuilder::default())),
|
||||
"Bitmap" => Ok(LanceDbIndex::Bitmap(Default::default())),
|
||||
"LabelList" => Ok(LanceDbIndex::LabelList(Default::default())),
|
||||
"FTS" => {
|
||||
let params = source.extract::<FtsParams>()?;
|
||||
let inner_opts = TokenizerConfig::default()
|
||||
.base_tokenizer(params.base_tokenizer)
|
||||
.language(¶ms.language)
|
||||
.map_err(|_| PyValueError::new_err(format!("LanceDB does not support the requested language: '{}'", params.language)))?
|
||||
.lower_case(params.lower_case)
|
||||
.max_token_length(params.max_token_length)
|
||||
.remove_stop_words(params.remove_stop_words)
|
||||
.stem(params.stem)
|
||||
.ascii_folding(params.ascii_folding);
|
||||
let mut opts = FtsIndexBuilder::default()
|
||||
.with_position(params.with_position);
|
||||
opts.tokenizer_configs = inner_opts;
|
||||
Ok(LanceDbIndex::FTS(opts))
|
||||
},
|
||||
"IvfFlat" => {
|
||||
let params = source.extract::<IvfFlatParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
let mut ivf_flat_builder = IvfFlatIndexBuilder::default()
|
||||
.distance_type(distance_type)
|
||||
.max_iterations(params.max_iterations)
|
||||
.sample_rate(params.sample_rate);
|
||||
if let Some(num_partitions) = params.num_partitions {
|
||||
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
|
||||
},
|
||||
"IvfPq" => {
|
||||
let params = source.extract::<IvfPqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
let mut ivf_pq_builder = IvfPqIndexBuilder::default()
|
||||
.distance_type(distance_type)
|
||||
.max_iterations(params.max_iterations)
|
||||
.sample_rate(params.sample_rate)
|
||||
.num_bits(params.num_bits);
|
||||
if let Some(num_partitions) = params.num_partitions {
|
||||
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = params.num_sub_vectors {
|
||||
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
|
||||
},
|
||||
"HnswPq" => {
|
||||
let params = source.extract::<IvfHnswPqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default()
|
||||
.distance_type(distance_type)
|
||||
.max_iterations(params.max_iterations)
|
||||
.sample_rate(params.sample_rate)
|
||||
.num_edges(params.m)
|
||||
.ef_construction(params.ef_construction)
|
||||
.num_bits(params.num_bits);
|
||||
if let Some(num_partitions) = params.num_partitions {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = params.num_sub_vectors {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))
|
||||
},
|
||||
"HnswSq" => {
|
||||
let params = source.extract::<IvfHnswSqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default()
|
||||
.distance_type(distance_type)
|
||||
.max_iterations(params.max_iterations)
|
||||
.sample_rate(params.sample_rate)
|
||||
.num_edges(params.m)
|
||||
.ef_construction(params.ef_construction);
|
||||
if let Some(num_partitions) = params.num_partitions {
|
||||
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
|
||||
},
|
||||
not_supported => Err(PyValueError::new_err(format!(
|
||||
"Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfHnswPq, or IvfHnswSq",
|
||||
not_supported
|
||||
))),
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = num_sub_vectors {
|
||||
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
if let Some(num_bits) = num_bits {
|
||||
ivf_pq_builder = ivf_pq_builder.num_bits(num_bits);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
|
||||
})
|
||||
} else {
|
||||
Ok(LanceDbIndex::Auto)
|
||||
}
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn btree() -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
|
||||
})
|
||||
}
|
||||
#[derive(FromPyObject)]
|
||||
struct FtsParams {
|
||||
with_position: bool,
|
||||
base_tokenizer: String,
|
||||
language: String,
|
||||
max_token_length: Option<usize>,
|
||||
lower_case: bool,
|
||||
stem: bool,
|
||||
remove_stop_words: bool,
|
||||
ascii_folding: bool,
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn bitmap() -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::Bitmap(Default::default()))),
|
||||
})
|
||||
}
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfFlatParams {
|
||||
distance_type: String,
|
||||
num_partitions: Option<u32>,
|
||||
max_iterations: u32,
|
||||
sample_rate: u32,
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn label_list() -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))),
|
||||
})
|
||||
}
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfPqParams {
|
||||
distance_type: String,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
num_bits: u32,
|
||||
max_iterations: u32,
|
||||
sample_rate: u32,
|
||||
}
|
||||
|
||||
#[pyo3(signature = (with_position=None, base_tokenizer=None, language=None, max_token_length=None, lower_case=None, stem=None, remove_stop_words=None, ascii_folding=None))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[staticmethod]
|
||||
pub fn fts(
|
||||
with_position: Option<bool>,
|
||||
base_tokenizer: Option<String>,
|
||||
language: Option<String>,
|
||||
max_token_length: Option<usize>,
|
||||
lower_case: Option<bool>,
|
||||
stem: Option<bool>,
|
||||
remove_stop_words: Option<bool>,
|
||||
ascii_folding: Option<bool>,
|
||||
) -> Self {
|
||||
let mut opts = FtsIndexBuilder::default();
|
||||
if let Some(with_position) = with_position {
|
||||
opts = opts.with_position(with_position);
|
||||
}
|
||||
if let Some(base_tokenizer) = base_tokenizer {
|
||||
opts.tokenizer_configs = opts.tokenizer_configs.base_tokenizer(base_tokenizer);
|
||||
}
|
||||
if let Some(language) = language {
|
||||
opts.tokenizer_configs = opts.tokenizer_configs.language(&language).unwrap();
|
||||
}
|
||||
opts.tokenizer_configs = opts.tokenizer_configs.max_token_length(max_token_length);
|
||||
if let Some(lower_case) = lower_case {
|
||||
opts.tokenizer_configs = opts.tokenizer_configs.lower_case(lower_case);
|
||||
}
|
||||
if let Some(stem) = stem {
|
||||
opts.tokenizer_configs = opts.tokenizer_configs.stem(stem);
|
||||
}
|
||||
if let Some(remove_stop_words) = remove_stop_words {
|
||||
opts.tokenizer_configs = opts.tokenizer_configs.remove_stop_words(remove_stop_words);
|
||||
}
|
||||
if let Some(ascii_folding) = ascii_folding {
|
||||
opts.tokenizer_configs = opts.tokenizer_configs.ascii_folding(ascii_folding);
|
||||
}
|
||||
Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
|
||||
}
|
||||
}
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfHnswPqParams {
|
||||
distance_type: String,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
num_bits: u32,
|
||||
max_iterations: u32,
|
||||
sample_rate: u32,
|
||||
m: u32,
|
||||
ef_construction: u32,
|
||||
}
|
||||
|
||||
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
|
||||
#[staticmethod]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn hnsw_pq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
num_bits: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
m: Option<u32>,
|
||||
ef_construction: Option<u32>,
|
||||
) -> PyResult<Self> {
|
||||
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = num_sub_vectors {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
if let Some(num_bits) = num_bits {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_bits(num_bits);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
if let Some(m) = m {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
|
||||
}
|
||||
if let Some(ef_construction) = ef_construction {
|
||||
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (distance_type=None, num_partitions=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
|
||||
#[staticmethod]
|
||||
pub fn hnsw_sq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
m: Option<u32>,
|
||||
ef_construction: Option<u32>,
|
||||
) -> PyResult<Self> {
|
||||
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
if let Some(m) = m {
|
||||
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
|
||||
}
|
||||
if let Some(ef_construction) = ef_construction {
|
||||
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
|
||||
})
|
||||
}
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfHnswSqParams {
|
||||
distance_type: String,
|
||||
num_partitions: Option<u32>,
|
||||
max_iterations: u32,
|
||||
sample_rate: u32,
|
||||
m: u32,
|
||||
ef_construction: u32,
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
use arrow::RecordBatchStream;
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::{Index, IndexConfig};
|
||||
use index::IndexConfig;
|
||||
use pyo3::{
|
||||
pymodule,
|
||||
types::{PyModule, PyModuleMethods},
|
||||
@@ -40,7 +40,6 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
env_logger::init_from_env(env);
|
||||
m.add_class::<Connection>()?;
|
||||
m.add_class::<Table>()?;
|
||||
m.add_class::<Index>()?;
|
||||
m.add_class::<IndexConfig>()?;
|
||||
m.add_class::<Query>()?;
|
||||
m.add_class::<VectorQuery>()?;
|
||||
|
||||
@@ -19,7 +19,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::{
|
||||
error::PythonErrorExt,
|
||||
index::{Index, IndexConfig},
|
||||
index::{extract_index_params, IndexConfig},
|
||||
query::Query,
|
||||
};
|
||||
|
||||
@@ -177,14 +177,10 @@ impl Table {
|
||||
pub fn create_index<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
column: String,
|
||||
index: Option<&Index>,
|
||||
index: Option<Bound<'_, PyAny>>,
|
||||
replace: Option<bool>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let index = if let Some(index) = index {
|
||||
index.consume()?
|
||||
} else {
|
||||
lancedb::index::Index::Auto
|
||||
};
|
||||
let index = extract_index_params(&index)?;
|
||||
let mut op = self_.inner_ref()?.create_index(&[column], index);
|
||||
if let Some(replace) = replace {
|
||||
op = op.replace(replace);
|
||||
|
||||
@@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
|
||||
"l2" => Ok(DistanceType::L2),
|
||||
"cosine" => Ok(DistanceType::Cosine),
|
||||
"dot" => Ok(DistanceType::Dot),
|
||||
"hamming" => Ok(DistanceType::Hamming),
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
|
||||
distance_type.as_ref()
|
||||
))),
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "1.80.0"
|
||||
channel = "1.83.0"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-node"
|
||||
version = "0.14.1-beta.1"
|
||||
version = "0.14.1"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.14.1-beta.1"
|
||||
version = "0.14.1"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
rust-version = "1.75"
|
||||
rust-version.workspace = true
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
|
||||
@@ -1050,6 +1050,8 @@ impl ConnectionInternal for Database {
|
||||
write_params.enable_v2_manifest_paths =
|
||||
options.enable_v2_manifest_paths.unwrap_or_default();
|
||||
|
||||
let data_schema = data.schema();
|
||||
|
||||
match NativeTable::create(
|
||||
&table_uri,
|
||||
&options.name,
|
||||
@@ -1069,7 +1071,18 @@ impl ConnectionInternal for Database {
|
||||
CreateTableMode::ExistOk(callback) => {
|
||||
let builder = OpenTableBuilder::new(options.parent, options.name);
|
||||
let builder = (callback)(builder);
|
||||
builder.execute().await
|
||||
let table = builder.execute().await?;
|
||||
|
||||
let table_schema = table.schema().await?;
|
||||
|
||||
if table_schema != data_schema {
|
||||
return Err(Error::Schema {
|
||||
message: "Provided schema does not match existing table schema"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(table)
|
||||
}
|
||||
CreateTableMode::Overwrite => unreachable!(),
|
||||
},
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::sync::Arc;
|
||||
use scalar::FtsIndexBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use vector::IvfFlatIndexBuilder;
|
||||
|
||||
use crate::{table::TableInternal, DistanceType, Error, Result};
|
||||
|
||||
@@ -56,6 +57,9 @@ pub enum Index {
|
||||
/// Full text search index using bm25.
|
||||
FTS(FtsIndexBuilder),
|
||||
|
||||
/// IVF index
|
||||
IvfFlat(IvfFlatIndexBuilder),
|
||||
|
||||
/// IVF index with Product Quantization
|
||||
IvfPq(IvfPqIndexBuilder),
|
||||
|
||||
@@ -106,6 +110,8 @@ impl IndexBuilder {
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub enum IndexType {
|
||||
// Vector
|
||||
#[serde(alias = "IVF_FLAT")]
|
||||
IvfFlat,
|
||||
#[serde(alias = "IVF_PQ")]
|
||||
IvfPq,
|
||||
#[serde(alias = "IVF_HNSW_PQ")]
|
||||
@@ -127,6 +133,7 @@ pub enum IndexType {
|
||||
impl std::fmt::Display for IndexType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::IvfFlat => write!(f, "IVF_FLAT"),
|
||||
Self::IvfPq => write!(f, "IVF_PQ"),
|
||||
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
|
||||
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
|
||||
@@ -147,6 +154,7 @@ impl std::str::FromStr for IndexType {
|
||||
"BITMAP" => Ok(Self::Bitmap),
|
||||
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
|
||||
"FTS" | "INVERTED" => Ok(Self::FTS),
|
||||
"IVF_FLAT" => Ok(Self::IvfFlat),
|
||||
"IVF_PQ" => Ok(Self::IvfPq),
|
||||
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
|
||||
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
|
||||
|
||||
@@ -77,5 +77,5 @@ impl FtsIndexBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
use lance_index::scalar::inverted::TokenizerConfig;
|
||||
pub use lance_index::scalar::inverted::TokenizerConfig;
|
||||
pub use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
@@ -162,6 +162,43 @@ macro_rules! impl_hnsw_params_setter {
|
||||
};
|
||||
}
|
||||
|
||||
/// Builder for an IVF Flat index.
|
||||
///
|
||||
/// This index stores raw vectors. These vectors are grouped into partitions of similar vectors.
|
||||
/// Each partition keeps track of a centroid which is the average value of all vectors in the group.
|
||||
///
|
||||
/// During a query the centroids are compared with the query vector to find the closest partitions.
|
||||
/// The raw vectors in these partitions are then searched to find the closest vectors.
|
||||
///
|
||||
/// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create.
|
||||
///
|
||||
/// Note that training an IVF Flat index on a large dataset is a slow operation and currently is also a memory intensive operation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IvfFlatIndexBuilder {
|
||||
pub(crate) distance_type: DistanceType,
|
||||
|
||||
// IVF
|
||||
pub(crate) num_partitions: Option<u32>,
|
||||
pub(crate) sample_rate: u32,
|
||||
pub(crate) max_iterations: u32,
|
||||
}
|
||||
|
||||
impl Default for IvfFlatIndexBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
distance_type: DistanceType::L2,
|
||||
num_partitions: None,
|
||||
sample_rate: 256,
|
||||
max_iterations: 50,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IvfFlatIndexBuilder {
|
||||
impl_distance_type_setter!();
|
||||
impl_ivf_params_setter!();
|
||||
}
|
||||
|
||||
/// Builder for an IVF PQ index.
|
||||
///
|
||||
/// This index stores a compressed (quantized) copy of every vector. These vectors
|
||||
|
||||
@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
|
||||
pub mod query;
|
||||
#[cfg(feature = "remote")]
|
||||
pub mod remote;
|
||||
pub mod rerankers;
|
||||
pub mod table;
|
||||
pub mod utils;
|
||||
|
||||
|
||||
@@ -15,19 +15,31 @@
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
|
||||
use arrow_schema::DataType;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use futures::{stream, try_join, FutureExt, TryStreamExt};
|
||||
use half::f16;
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::{
|
||||
arrow::RecordBatchExt,
|
||||
dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
|
||||
};
|
||||
use lance_datafusion::exec::execute_plan;
|
||||
use lance_index::scalar::inverted::SCORE_COL;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
use lance_index::vector::DIST_COL;
|
||||
use lance_io::stream::RecordBatchStreamAdapter;
|
||||
|
||||
use crate::arrow::SendableRecordBatchStream;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::rerankers::rrf::RRFReranker;
|
||||
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
|
||||
use crate::table::TableInternal;
|
||||
use crate::DistanceType;
|
||||
|
||||
mod hybrid;
|
||||
|
||||
pub(crate) const DEFAULT_TOP_K: usize = 10;
|
||||
|
||||
/// Which columns should be retrieved from the database
|
||||
@@ -339,7 +351,7 @@ pub trait QueryBase {
|
||||
fn limit(self, limit: usize) -> Self;
|
||||
|
||||
/// Set the offset of the query.
|
||||
|
||||
///
|
||||
/// By default, it fetches starting with the first row.
|
||||
/// This method can be used to skip the first `offset` rows.
|
||||
fn offset(self, offset: usize) -> Self;
|
||||
@@ -435,6 +447,16 @@ pub trait QueryBase {
|
||||
|
||||
/// Return the `_rowid` meta column from the Table.
|
||||
fn with_row_id(self) -> Self;
|
||||
|
||||
/// Rerank the results using the specified reranker.
|
||||
///
|
||||
/// This is currently only supported for Hybrid Search.
|
||||
fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;
|
||||
|
||||
/// The method to normalize the scores. Can be "rank" or "Score". If "Rank",
|
||||
/// the scores are converted to ranks and then normalized. If "Score", the
|
||||
/// scores are normalized directly.
|
||||
fn norm(self, norm: NormalizeMethod) -> Self;
|
||||
}
|
||||
|
||||
pub trait HasQuery {
|
||||
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
|
||||
self.mut_query().with_row_id = true;
|
||||
self
|
||||
}
|
||||
|
||||
fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
|
||||
self.mut_query().reranker = Some(reranker);
|
||||
self
|
||||
}
|
||||
|
||||
fn norm(mut self, norm: NormalizeMethod) -> Self {
|
||||
self.mut_query().norm = Some(norm);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for controlling the execution of a query
|
||||
@@ -600,6 +632,13 @@ pub struct Query {
|
||||
|
||||
/// If set to false, the filter will be applied after the vector search.
|
||||
pub(crate) prefilter: bool,
|
||||
|
||||
/// Implementation of reranker that can be used to reorder or combine query
|
||||
/// results, especially if using hybrid search
|
||||
pub(crate) reranker: Option<Arc<dyn Reranker>>,
|
||||
|
||||
/// Configure how query results are normalized when doing hybrid search
|
||||
pub(crate) norm: Option<NormalizeMethod>,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -614,6 +653,8 @@ impl Query {
|
||||
fast_search: false,
|
||||
with_row_id: false,
|
||||
prefilter: true,
|
||||
reranker: None,
|
||||
norm: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -862,6 +903,65 @@ impl VectorQuery {
|
||||
self.use_index = false;
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
|
||||
// clone query and specify we want to include row IDs, which can be needed for reranking
|
||||
let fts_query = self.base.clone().with_row_id();
|
||||
let mut vector_query = self.clone().with_row_id();
|
||||
|
||||
vector_query.base.full_text_search = None;
|
||||
let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
|
||||
|
||||
let (fts_results, vec_results) = try_join!(
|
||||
fts_results.try_collect::<Vec<_>>(),
|
||||
vec_results.try_collect::<Vec<_>>()
|
||||
)?;
|
||||
|
||||
// try to get the schema to use when combining batches.
|
||||
// if either
|
||||
let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);
|
||||
|
||||
// concatenate all the batches together
|
||||
let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
|
||||
let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
|
||||
|
||||
if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
|
||||
vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
|
||||
fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
|
||||
}
|
||||
|
||||
vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
|
||||
fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
|
||||
|
||||
let reranker = self
|
||||
.base
|
||||
.reranker
|
||||
.clone()
|
||||
.unwrap_or(Arc::new(RRFReranker::default()));
|
||||
|
||||
let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
|
||||
message: "there should be an FTS search".to_string(),
|
||||
})?;
|
||||
|
||||
let mut results = reranker
|
||||
.rerank_hybrid(&fts_query.query, vec_results, fts_results)
|
||||
.await?;
|
||||
|
||||
check_reranker_result(&results)?;
|
||||
|
||||
let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
|
||||
if results.num_rows() > limit {
|
||||
results = results.slice(0, limit);
|
||||
}
|
||||
|
||||
if !self.base.with_row_id {
|
||||
results = results.drop_column(ROW_ID)?;
|
||||
}
|
||||
|
||||
Ok(SendableRecordBatchStream::from(
|
||||
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutableQuery for VectorQuery {
|
||||
@@ -873,6 +973,11 @@ impl ExecutableQuery for VectorQuery {
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
if self.base.full_text_search.is_some() {
|
||||
let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
|
||||
return Ok(hybrid_result);
|
||||
}
|
||||
|
||||
Ok(SendableRecordBatchStream::from(
|
||||
DatasetRecordBatchStream::new(execute_plan(
|
||||
self.create_plan(options).await?,
|
||||
@@ -894,20 +999,20 @@ impl HasQuery for VectorQuery {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use super::*;
|
||||
use arrow::{compute::concat_batches, datatypes::Int32Type};
|
||||
use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
|
||||
use arrow_array::{
|
||||
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
|
||||
RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
|
||||
};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::{connect, Table};
|
||||
use crate::{connect, connection::CreateTableMode, Table};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_setters_getters() {
|
||||
@@ -1274,4 +1379,156 @@ mod tests {
|
||||
assert!(query_index.values().contains(&0));
|
||||
assert!(query_index.values().contains(&1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hybrid_search() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path();
|
||||
let conn = connect(dataset_path.to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let dims = 2;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("text", DataType::Utf8, false),
|
||||
ArrowField::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
||||
dims,
|
||||
),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
|
||||
let vectors = vec![
|
||||
Some(vec![Some(0.0), Some(0.0)]),
|
||||
Some(vec![Some(-2.0), Some(-2.0)]),
|
||||
Some(vec![Some(50.0), Some(50.0)]),
|
||||
Some(vec![Some(-30.0), Some(-30.0)]),
|
||||
];
|
||||
let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);
|
||||
|
||||
let record_batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
|
||||
let record_batch_iter =
|
||||
RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch_iter)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
table
|
||||
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
|
||||
.replace(true)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let fts_query = FullTextSearchQuery::new("b".to_string());
|
||||
let results = table
|
||||
.query()
|
||||
.full_text_search(fts_query)
|
||||
.limit(2)
|
||||
.nearest_to(&[-10.0, -10.0])
|
||||
.unwrap()
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = &results[0];
|
||||
|
||||
let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
|
||||
let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
|
||||
assert!(texts.contains("cat")); // should be close by vector search
|
||||
assert!(texts.contains("b")); // should be close by fts search
|
||||
|
||||
// ensure that this works correctly if there are no matching FTS results
|
||||
let fts_query = FullTextSearchQuery::new("z".to_string());
|
||||
table
|
||||
.query()
|
||||
.full_text_search(fts_query)
|
||||
.limit(2)
|
||||
.nearest_to(&[-10.0, -10.0])
|
||||
.unwrap()
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hybrid_search_empty_table() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path();
|
||||
let conn = connect(dataset_path.to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let dims = 2;
|
||||
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("text", DataType::Utf8, false),
|
||||
ArrowField::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
||||
dims,
|
||||
),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
// ensure hybrid search is also supported on a fully empty table
|
||||
let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
|
||||
let record_batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(Vec::<&str>::new())),
|
||||
Arc::new(
|
||||
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let record_batch_iter =
|
||||
RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch_iter)
|
||||
.mode(CreateTableMode::Overwrite)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
table
|
||||
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
|
||||
.replace(true)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let fts_query = FullTextSearchQuery::new("b".to_string());
|
||||
let results = table
|
||||
.query()
|
||||
.full_text_search(fts_query)
|
||||
.limit(2)
|
||||
.nearest_to(&[-10.0, -10.0])
|
||||
.unwrap()
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = &results[0];
|
||||
assert_eq!(0, batch.num_rows());
|
||||
assert_eq!(2, batch.num_columns());
|
||||
}
|
||||
}
|
||||
|
||||
346
rust/lancedb/src/query/hybrid.rs
Normal file
346
rust/lancedb/src/query/hybrid.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow::compute::{
|
||||
kernels::numeric::{div, sub},
|
||||
max, min,
|
||||
};
|
||||
use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
|
||||
use arrow_schema::{DataType, Field, Schema, SortOptions};
|
||||
use lance::dataset::ROW_ID;
|
||||
use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// Converts results's score column to a rank.
|
||||
///
|
||||
/// Expects the `column` argument to be type Float32 and will panic if it's not
|
||||
pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
|
||||
let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
|
||||
message: format!(
|
||||
"expected column {} not found in rank. found columns {:?}",
|
||||
column,
|
||||
results
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|f| f.name())
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
})?;
|
||||
|
||||
if results.num_rows() == 0 {
|
||||
return Ok(results);
|
||||
}
|
||||
|
||||
let scores: Float32Array = downcast_array(scores);
|
||||
let ranks = Float32Array::from_iter_values(
|
||||
arrow::compute::kernels::rank::rank(
|
||||
&scores,
|
||||
Some(SortOptions {
|
||||
descending: !ascending.unwrap_or(true),
|
||||
..Default::default()
|
||||
}),
|
||||
)?
|
||||
.iter()
|
||||
.map(|i| *i as f32),
|
||||
);
|
||||
|
||||
let schema = results.schema();
|
||||
let (column_idx, _) = schema.column_with_name(column).unwrap();
|
||||
let mut columns = results.columns().to_vec();
|
||||
columns[column_idx] = Arc::new(ranks);
|
||||
|
||||
let results = RecordBatch::try_new(results.schema(), columns)?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Get the query schemas needed when combining the search results.
|
||||
///
|
||||
/// If either of the record batches are empty, then we create a schema from the
|
||||
/// other record batch, and replace the score/distance column. If both record
|
||||
/// batches are empty, create empty schemas.
|
||||
pub fn query_schemas(
|
||||
fts_results: &[RecordBatch],
|
||||
vec_results: &[RecordBatch],
|
||||
) -> (Arc<Schema>, Arc<Schema>) {
|
||||
let (fts_schema, vec_schema) = match (
|
||||
fts_results.first().map(|r| r.schema()),
|
||||
vec_results.first().map(|r| r.schema()),
|
||||
) {
|
||||
(Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
|
||||
(None, Some(vec_schema)) => {
|
||||
let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
|
||||
(Arc::new(fts_schema), vec_schema)
|
||||
}
|
||||
(Some(fts_schema), None) => {
|
||||
let vec_schema = with_field_name_replaced(&fts_schema, DIST_COL, SCORE_COL);
|
||||
(fts_schema, Arc::new(vec_schema))
|
||||
}
|
||||
(None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
|
||||
};
|
||||
|
||||
(fts_schema, vec_schema)
|
||||
}
|
||||
|
||||
pub fn empty_fts_schema() -> Schema {
|
||||
Schema::new(vec![
|
||||
Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
|
||||
Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
|
||||
])
|
||||
}
|
||||
|
||||
pub fn empty_vec_schema() -> Schema {
|
||||
Schema::new(vec![
|
||||
Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
|
||||
Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
|
||||
])
|
||||
}
|
||||
|
||||
pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
|
||||
let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
|
||||
if field.name() == target {
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let mut fields = schema.fields().to_vec();
|
||||
if let Some(idx) = field_idx {
|
||||
let new_field = (*fields[idx]).clone().with_name(replacement);
|
||||
fields[idx] = Arc::new(new_field);
|
||||
}
|
||||
|
||||
Schema::new(fields)
|
||||
}
|
||||
|
||||
/// Normalize the scores column to have values between 0 and 1.
|
||||
///
|
||||
/// Expects the `column` argument to be type Float32 and will panic if it's not
|
||||
pub fn normalize_scores(
|
||||
results: RecordBatch,
|
||||
column: &str,
|
||||
invert: Option<bool>,
|
||||
) -> Result<RecordBatch> {
|
||||
let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
|
||||
message: format!(
|
||||
"expected column {} not found in rank. found columns {:?}",
|
||||
column,
|
||||
results
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|f| f.name())
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
})?;
|
||||
|
||||
if results.num_rows() == 0 {
|
||||
return Ok(results);
|
||||
}
|
||||
let mut scores: Float32Array = downcast_array(scores);
|
||||
|
||||
let max = max(&scores).unwrap_or(0.0);
|
||||
let min = min(&scores).unwrap_or(0.0);
|
||||
|
||||
// this is equivalent to np.isclose which is used in python
|
||||
let rng = if max - min < 10e-5 { max } else { max - min };
|
||||
|
||||
// if rng is 0, then min and max are both 0 so we just leave the scores as is
|
||||
if rng != 0.0 {
|
||||
let tmp = div(
|
||||
&sub(&scores, &Float32Array::new_scalar(min))?,
|
||||
&Float32Array::new_scalar(rng),
|
||||
)?;
|
||||
scores = downcast_array(&tmp);
|
||||
}
|
||||
|
||||
if invert.unwrap_or(false) {
|
||||
let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
|
||||
scores = downcast_array(&tmp);
|
||||
}
|
||||
|
||||
let schema = results.schema();
|
||||
let (column_idx, _) = schema.column_with_name(column).unwrap();
|
||||
let mut columns = results.columns().to_vec();
|
||||
columns[column_idx] = Arc::new(scores);
|
||||
|
||||
let results = RecordBatch::try_new(results.schema(), columns).unwrap();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use arrow_array::StringArray;
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
|
||||
#[test]
|
||||
fn test_rank() {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Arc::new(Field::new("name", DataType::Utf8, false)),
|
||||
Arc::new(Field::new("score", DataType::Float32, false)),
|
||||
]));
|
||||
|
||||
let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]);
|
||||
let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]);
|
||||
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap();
|
||||
|
||||
let result = rank(batch.clone(), "score", Some(false)).unwrap();
|
||||
assert_eq!(2, result.schema().fields().len());
|
||||
assert_eq!("name", result.schema().field(0).name());
|
||||
assert_eq!("score", result.schema().field(1).name());
|
||||
|
||||
let names: StringArray = downcast_array(result.column(0));
|
||||
assert_eq!(
|
||||
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec!["foo", "bar", "baz", "bean", "dog"]
|
||||
);
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![4.0, 3.0, 5.0, 1.0, 2.0]
|
||||
);
|
||||
|
||||
// check sort ascending
|
||||
let result = rank(batch.clone(), "score", Some(true)).unwrap();
|
||||
let names: StringArray = downcast_array(result.column(0));
|
||||
assert_eq!(
|
||||
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec!["foo", "bar", "baz", "bean", "dog"]
|
||||
);
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![2.0, 3.0, 1.0, 5.0, 4.0]
|
||||
);
|
||||
|
||||
// ensure default sort is ascending
|
||||
let result = rank(batch.clone(), "score", None).unwrap();
|
||||
let names: StringArray = downcast_array(result.column(0));
|
||||
assert_eq!(
|
||||
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec!["foo", "bar", "baz", "bean", "dog"]
|
||||
);
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![2.0, 3.0, 1.0, 5.0, 4.0]
|
||||
);
|
||||
|
||||
// check it can handle an empty batch
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(Vec::<&str>::new())),
|
||||
Arc::new(Float32Array::from(Vec::<f32>::new())),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let result = rank(batch.clone(), "score", None).unwrap();
|
||||
assert_eq!(0, result.num_rows());
|
||||
assert_eq!(2, result.schema().fields().len());
|
||||
assert_eq!("name", result.schema().field(0).name());
|
||||
assert_eq!("score", result.schema().field(1).name());
|
||||
|
||||
// check it returns the expected error when there's no column
|
||||
let result = rank(batch.clone(), "bad_col", None);
|
||||
match result {
|
||||
Err(Error::InvalidInput { message }) => {
|
||||
assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
|
||||
}
|
||||
_ => {
|
||||
panic!("expected invalid input error, received {:?}", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_scores() {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Arc::new(Field::new("name", DataType::Utf8, false)),
|
||||
Arc::new(Field::new("score", DataType::Float32, false)),
|
||||
]));
|
||||
|
||||
let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
|
||||
let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));
|
||||
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
|
||||
|
||||
let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
|
||||
let names: StringArray = downcast_array(result.column(0));
|
||||
assert_eq!(
|
||||
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec!["foo", "bar", "baz", "bean", "dog"]
|
||||
);
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![0.0, 0.6, 0.4, 0.7, 1.0]
|
||||
);
|
||||
|
||||
// check it can invert the normalization
|
||||
let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
|
||||
);
|
||||
|
||||
// check that the default is not inverted
|
||||
let result = normalize_scores(batch.clone(), "score", None).unwrap();
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![0.0, 0.6, 0.4, 0.7, 1.0]
|
||||
);
|
||||
|
||||
// check that it will function correctly if all the values are the same
|
||||
let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
|
||||
let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
|
||||
let result = normalize_scores(batch.clone(), "score", None).unwrap();
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
);
|
||||
|
||||
// check it keeps floating point rounding errors for same score normalized the same
|
||||
// e.g., the behaviour is consistent with python
|
||||
let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
|
||||
let result = normalize_scores(batch.clone(), "score", None).unwrap();
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![
|
||||
1.0 - 0.9999999,
|
||||
1.0 - 0.9999999,
|
||||
1.0 - 0.9999999,
|
||||
1.0 - 0.9999999,
|
||||
0.0
|
||||
]
|
||||
);
|
||||
|
||||
// check that it can handle if all the scores are 0
|
||||
let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
|
||||
let result = normalize_scores(batch.clone(), "score", None).unwrap();
|
||||
let scores: Float32Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -563,6 +563,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
let (index_type, distance_type) = match index.index {
|
||||
// TODO: Should we pass the actual index parameters? SaaS does not
|
||||
// yet support them.
|
||||
Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)),
|
||||
Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
|
||||
Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
|
||||
Index::BTree(_) => ("BTREE", None),
|
||||
@@ -873,6 +874,7 @@ mod tests {
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
use reqwest::Body;
|
||||
|
||||
use crate::index::vector::IvfFlatIndexBuilder;
|
||||
use crate::{
|
||||
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
@@ -1489,6 +1491,11 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_create_index() {
|
||||
let cases = [
|
||||
(
|
||||
"IVF_FLAT",
|
||||
Some("hamming"),
|
||||
Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
|
||||
),
|
||||
("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
|
||||
(
|
||||
"IVF_PQ",
|
||||
|
||||
87
rust/lancedb/src/rerankers.rs
Normal file
87
rust/lancedb/src/rerankers.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use arrow::{
|
||||
array::downcast_array,
|
||||
compute::{concat_batches, filter_record_batch},
|
||||
};
|
||||
use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
|
||||
use async_trait::async_trait;
|
||||
use lance::dataset::ROW_ID;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
pub mod rrf;
|
||||
|
||||
/// column name for reranker relevance score
|
||||
const RELEVANCE_SCORE: &str = "_relevance_score";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum NormalizeMethod {
|
||||
Score,
|
||||
Rank,
|
||||
}
|
||||
|
||||
/// Interface for a reranker. A reranker is used to rerank the results from a
|
||||
/// vector and FTS search. This is useful for combining the results from both
|
||||
/// search methods.
|
||||
#[async_trait]
|
||||
pub trait Reranker: std::fmt::Debug + Sync + Send {
|
||||
// TODO support vector reranking and FTS reranking. Currently only hybrid reranking is supported.
|
||||
|
||||
/// Rerank function receives the individual results from the vector and FTS search
|
||||
/// results. You can choose to use any of the results to generate the final results,
|
||||
/// allowing maximum flexibility.
|
||||
async fn rerank_hybrid(
|
||||
&self,
|
||||
query: &str,
|
||||
vector_results: RecordBatch,
|
||||
fts_results: RecordBatch,
|
||||
) -> Result<RecordBatch>;
|
||||
|
||||
fn merge_results(
|
||||
&self,
|
||||
vector_results: RecordBatch,
|
||||
fts_results: RecordBatch,
|
||||
) -> Result<RecordBatch> {
|
||||
let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;
|
||||
|
||||
let mut mask = BooleanArray::builder(combined.num_rows());
|
||||
let mut unique_ids = BTreeSet::new();
|
||||
let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
|
||||
message: format!(
|
||||
"could not find expected column {} while merging results. found columns {:?}",
|
||||
ROW_ID,
|
||||
combined
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|f| f.name())
|
||||
.collect::<Vec<_>>()
|
||||
),
|
||||
})?;
|
||||
let row_ids: UInt64Array = downcast_array(row_ids);
|
||||
row_ids.values().iter().for_each(|id| {
|
||||
mask.append_value(unique_ids.insert(id));
|
||||
});
|
||||
|
||||
let combined = filter_record_batch(&combined, &mask.finish())?;
|
||||
|
||||
Ok(combined)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
|
||||
if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
|
||||
return Err(Error::Schema {
|
||||
message: format!(
|
||||
"rerank_hybrid must return a RecordBatch with a column named {}",
|
||||
RELEVANCE_SCORE
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
223
rust/lancedb/src/rerankers/rrf.rs
Normal file
223
rust/lancedb/src/rerankers/rrf.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::{
|
||||
array::downcast_array,
|
||||
compute::{sort_to_indices, take},
|
||||
};
|
||||
use arrow_array::{Float32Array, RecordBatch, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema, SortOptions};
|
||||
use async_trait::async_trait;
|
||||
use lance::dataset::ROW_ID;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::rerankers::{Reranker, RELEVANCE_SCORE};
|
||||
|
||||
/// Reranks the results using Reciprocal Rank Fusion(RRF) algorithm based
|
||||
/// on the scores of vector and FTS search.
|
||||
///
|
||||
#[derive(Debug)]
|
||||
pub struct RRFReranker {
|
||||
k: f32,
|
||||
}
|
||||
|
||||
impl RRFReranker {
|
||||
/// Create a new RRFReranker
|
||||
///
|
||||
/// The parameter k is a constant used in the RRF formula (default is 60).
|
||||
/// Experiments indicate that k = 60 was near-optimal, but that the choice
|
||||
/// is not critical. See paper:
|
||||
/// https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
|
||||
pub fn new(k: f32) -> Self {
|
||||
Self { k }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RRFReranker {
|
||||
fn default() -> Self {
|
||||
Self { k: 60.0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Reranker for RRFReranker {
|
||||
async fn rerank_hybrid(
|
||||
&self,
|
||||
_query: &str,
|
||||
vector_results: RecordBatch,
|
||||
fts_results: RecordBatch,
|
||||
) -> Result<RecordBatch> {
|
||||
let vector_ids = vector_results
|
||||
.column_by_name(ROW_ID)
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: format!(
|
||||
"expected column {} not found in vector_results. found columns {:?}",
|
||||
ROW_ID,
|
||||
vector_results
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|f| f.name())
|
||||
.collect::<Vec<_>>()
|
||||
),
|
||||
})?;
|
||||
let fts_ids = fts_results
|
||||
.column_by_name(ROW_ID)
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: format!(
|
||||
"expected column {} not found in fts_results. found columns {:?}",
|
||||
ROW_ID,
|
||||
fts_results
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|f| f.name())
|
||||
.collect::<Vec<_>>()
|
||||
),
|
||||
})?;
|
||||
|
||||
let vector_ids: UInt64Array = downcast_array(&vector_ids);
|
||||
let fts_ids: UInt64Array = downcast_array(&fts_ids);
|
||||
|
||||
let mut rrf_score_map = BTreeMap::new();
|
||||
let mut update_score_map = |(i, result_id)| {
|
||||
let score = 1.0 / (i as f32 + self.k);
|
||||
rrf_score_map
|
||||
.entry(result_id)
|
||||
.and_modify(|e| *e += score)
|
||||
.or_insert(score);
|
||||
};
|
||||
vector_ids
|
||||
.values()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.for_each(&mut update_score_map);
|
||||
fts_ids
|
||||
.values()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.for_each(&mut update_score_map);
|
||||
|
||||
let combined_results = self.merge_results(vector_results, fts_results)?;
|
||||
|
||||
let combined_row_ids: UInt64Array =
|
||||
downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
|
||||
let relevance_scores = Float32Array::from_iter_values(
|
||||
combined_row_ids
|
||||
.values()
|
||||
.iter()
|
||||
.map(|row_id| rrf_score_map.get(row_id).unwrap())
|
||||
.copied(),
|
||||
);
|
||||
|
||||
// keep track of indices sorted by the relevance column
|
||||
let sort_indices = sort_to_indices(
|
||||
&relevance_scores,
|
||||
Some(SortOptions {
|
||||
descending: true,
|
||||
..Default::default()
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// add relevance scores to columns
|
||||
let mut columns = combined_results.columns().to_vec();
|
||||
columns.push(Arc::new(relevance_scores));
|
||||
|
||||
// sort by the relevance scores
|
||||
let columns = columns
|
||||
.iter()
|
||||
.map(|c| take(c, &sort_indices, None).unwrap())
|
||||
.collect();
|
||||
|
||||
// add relevance score to schema
|
||||
let mut fields = combined_results.schema().fields().to_vec();
|
||||
fields.push(Arc::new(Field::new(
|
||||
RELEVANCE_SCORE,
|
||||
DataType::Float32,
|
||||
false,
|
||||
)));
|
||||
let schema = Schema::new(fields);
|
||||
|
||||
let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;
|
||||
|
||||
Ok(combined_results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use super::*;
|
||||
use arrow_array::StringArray;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rrf_reranker() {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("name", DataType::Utf8, false),
|
||||
Field::new(ROW_ID, DataType::UInt64, false),
|
||||
]));
|
||||
|
||||
let vec_results = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"])),
|
||||
Arc::new(UInt64Array::from(vec![1, 4, 2, 5, 3])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let fts_results = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["bar", "bean", "dog"])),
|
||||
Arc::new(UInt64Array::from(vec![4, 5, 3])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// scores should be calculated as:
|
||||
// - foo = 1/1 = 1.0
|
||||
// - bar = 1/2 + 1/1 = 1.5
|
||||
// - baz = 1/3 = 0.333
|
||||
// - bean = 1/4 + 1/2 = 0.75
|
||||
// - dog = 1/5 + 1/3 = 0.533
|
||||
// then we should get the result ranked in descending order
|
||||
|
||||
let reranker = RRFReranker::new(1.0);
|
||||
|
||||
let result = reranker
|
||||
.rerank_hybrid("", vec_results, fts_results)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(3, result.schema().fields().len());
|
||||
assert_eq!("name", result.schema().fields().first().unwrap().name());
|
||||
assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name());
|
||||
assert_eq!(
|
||||
RELEVANCE_SCORE,
|
||||
result.schema().fields().get(2).unwrap().name()
|
||||
);
|
||||
|
||||
let names: StringArray = downcast_array(result.column(0));
|
||||
assert_eq!(
|
||||
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec!["bar", "foo", "bean", "dog", "baz"]
|
||||
);
|
||||
|
||||
let ids: UInt64Array = downcast_array(result.column(1));
|
||||
assert_eq!(
|
||||
ids.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![4, 1, 5, 3, 2]
|
||||
);
|
||||
|
||||
let scores: Float32Array = downcast_array(result.column(2));
|
||||
assert_eq!(
|
||||
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
|
||||
vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0]
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,9 @@ use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::array::AsArray;
|
||||
use arrow::datatypes::Float32Type;
|
||||
use arrow::datatypes::{Float32Type, UInt8Type};
|
||||
use arrow_array::{RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{Field, Schema, SchemaRef};
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::display::DisplayableExecutionPlan;
|
||||
use datafusion_physical_plan::projection::ProjectionExec;
|
||||
@@ -58,8 +58,8 @@ use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, M
|
||||
use crate::error::{Error, Result};
|
||||
use crate::index::scalar::FtsIndexBuilder;
|
||||
use crate::index::vector::{
|
||||
suggested_num_partitions_for_hnsw, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder,
|
||||
IvfPqIndexBuilder, VectorIndex,
|
||||
suggested_num_partitions_for_hnsw, IvfFlatIndexBuilder, IvfHnswPqIndexBuilder,
|
||||
IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
|
||||
};
|
||||
use crate::index::IndexStatistics;
|
||||
use crate::index::{
|
||||
@@ -1306,6 +1306,44 @@ impl NativeTable {
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn create_ivf_flat_index(
|
||||
&self,
|
||||
index: IvfFlatIndexBuilder,
|
||||
field: &Field,
|
||||
replace: bool,
|
||||
) -> Result<()> {
|
||||
if !supported_vector_data_type(field.data_type()) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"An IVF Flat index cannot be created on the column `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let num_partitions = if let Some(n) = index.num_partitions {
|
||||
n
|
||||
} else {
|
||||
suggested_num_partitions(self.count_rows(None).await?)
|
||||
};
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_flat(
|
||||
num_partitions as usize,
|
||||
index.distance_type.into(),
|
||||
);
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Vector,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_ivf_pq_index(
|
||||
&self,
|
||||
index: IvfPqIndexBuilder,
|
||||
@@ -1778,6 +1816,10 @@ impl TableInternal for NativeTable {
|
||||
Index::Bitmap(_) => self.create_bitmap_index(field, opts).await,
|
||||
Index::LabelList(_) => self.create_label_list_index(field, opts).await,
|
||||
Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await,
|
||||
Index::IvfFlat(ivf_flat) => {
|
||||
self.create_ivf_flat_index(ivf_flat, field, opts.replace)
|
||||
.await
|
||||
}
|
||||
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
|
||||
Index::IvfHnswPq(ivf_hnsw_pq) => {
|
||||
self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
|
||||
@@ -1848,14 +1890,21 @@ impl TableInternal for NativeTable {
|
||||
message: format!("Column {} not found in dataset schema", column),
|
||||
})?;
|
||||
|
||||
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
|
||||
if !f.data_type().is_floating() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The data type of the vector column '{}' is not a floating point type",
|
||||
column
|
||||
),
|
||||
});
|
||||
let mut is_binary = false;
|
||||
if let arrow_schema::DataType::FixedSizeList(element, dim) = field.data_type() {
|
||||
match element.data_type() {
|
||||
e_type if e_type.is_floating() => {}
|
||||
e_type if *e_type == DataType::UInt8 => {
|
||||
is_binary = true;
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The data type of the vector column '{}' is not a floating point type",
|
||||
column
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
if dim != query_vector.len() as i32 {
|
||||
return Err(Error::InvalidInput {
|
||||
@@ -1870,12 +1919,22 @@ impl TableInternal for NativeTable {
|
||||
}
|
||||
}
|
||||
|
||||
let query_vector = query_vector.as_primitive::<Float32Type>();
|
||||
scanner.nearest(
|
||||
&column,
|
||||
query_vector,
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
if is_binary {
|
||||
let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
|
||||
let query_vector = query_vector.as_primitive::<UInt8Type>();
|
||||
scanner.nearest(
|
||||
&column,
|
||||
query_vector,
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
} else {
|
||||
let query_vector = query_vector.as_primitive::<Float32Type>();
|
||||
scanner.nearest(
|
||||
&column,
|
||||
query_vector,
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
scanner.limit(
|
||||
query.base.limit.map(|limit| limit as i64),
|
||||
|
||||
@@ -110,7 +110,7 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
|
||||
.iter()
|
||||
.filter_map(|field| match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(f, d)
|
||||
if f.data_type().is_floating()
|
||||
if (f.data_type().is_floating() || f.data_type() == &DataType::UInt8)
|
||||
&& dim.map(|expect| *d == expect).unwrap_or(true) =>
|
||||
{
|
||||
Some(field.name())
|
||||
@@ -171,7 +171,9 @@ pub fn supported_fts_data_type(dtype: &DataType) -> bool {
|
||||
|
||||
pub fn supported_vector_data_type(dtype: &DataType) -> bool {
|
||||
match dtype {
|
||||
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
|
||||
DataType::FixedSizeList(inner, _) => {
|
||||
DataType::is_floating(inner.data_type()) || *inner.data_type() == DataType::UInt8
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user