Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits: python-v0. ... python-v0. (21 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a27c5cf12b |  |
|  | f4dea72cc5 |  |
|  | f76c4a5ce1 |  |
|  | 164ce397c2 |  |
|  | 445a312667 |  |
|  | 92d845fa72 |  |
|  | 397813f6a4 |  |
|  | 50c30c5d34 |  |
|  | c9f248b058 |  |
|  | 0cb6da6b7e |  |
|  | aec8332eb5 |  |
|  | 46061070e6 |  |
|  | dae8334d0b |  |
|  | 8c81968b59 |  |
|  | 16cf2990f3 |  |
|  | 0a0f667bbd |  |
|  | 03753fd84b |  |
|  | 55cceaa309 |  |
|  | c3797eb834 |  |
|  | c0d0f38494 |  |
|  | 6a8ab78d0a |  |
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.14.1-beta.6"
+current_version = "0.14.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
.github/workflows/npm-publish.yml (vendored, 196 changes)
@@ -159,7 +159,7 @@ jobs:
       - name: Install common dependencies
         run: |
           apk add protobuf-dev curl clang mold grep npm bash
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
           echo "source $HOME/.cargo/env" >> saved_env
           echo "export CC=clang" >> saved_env
           echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -167,7 +167,7 @@ jobs:
         if: ${{ matrix.config.arch == 'aarch64' }}
         run: |
           source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
           crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
           sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
           apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -262,7 +262,7 @@ jobs:
       - name: Install common dependencies
         run: |
           apk add protobuf-dev curl clang mold grep npm bash openssl-dev openssl-libs-static
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
           echo "source $HOME/.cargo/env" >> saved_env
           echo "export CC=clang" >> saved_env
           echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -272,7 +272,7 @@ jobs:
         if: ${{ matrix.config.arch == 'aarch64' }}
         run: |
           source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
           crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
           sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
           apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -334,50 +334,51 @@ jobs:
           path: |
             node/dist/lancedb-vectordb-win32*.tgz

-  node-windows-arm64:
-    name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: node-native-windows-${{ matrix.config.arch }}
-          path: |
-            node/dist/lancedb-vectordb-win32*.tgz
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # node-windows-arm64:
+  #   name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: node-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           node/dist/lancedb-vectordb-win32*.tgz

   nodejs-windows:
     name: lancedb ${{ matrix.target }}
@@ -413,57 +414,58 @@ jobs:
           path: |
             nodejs/dist/*.node

-  nodejs-windows-arm64:
-    name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
-    # Only runs on tags that matches the make-release action
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-          printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
-          chmod u+x $HOME/.cargo/bin/cargo-xwin
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: nodejs-native-windows-${{ matrix.config.arch }}
-          path: |
-            nodejs/dist/*.node
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # nodejs-windows-arm64:
+  #   name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # Only runs on tags that matches the make-release action
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #         printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
+  #         chmod u+x $HOME/.cargo/bin/cargo-xwin
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: nodejs-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           nodejs/dist/*.node

   release:
     name: vectordb NPM Publish
-    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows, node-windows-arm64]
+    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows]
     runs-on: ubuntu-latest
     # Only runs on tags that matches the make-release action
     if: startsWith(github.ref, 'refs/tags/v')
@@ -503,7 +505,7 @@ jobs:

   release-nodejs:
     name: lancedb NPM Publish
-    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows, nodejs-windows-arm64]
+    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows]
     runs-on: ubuntu-latest
     # Only runs on tags that matches the make-release action
     if: startsWith(github.ref, 'refs/tags/v')
.github/workflows/python.yml (vendored, 4 changes)
@@ -30,10 +30,10 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.12"
       - name: Install ruff
        run: |
-          pip install ruff==0.5.4
+          pip install ruff==0.8.4
       - name: Format check
         run: ruff format --check .
       - name: Lint
Cargo.toml (18 changes)
@@ -21,16 +21,16 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"

 [workspace.dependencies]
-lance = { "version" = "=0.21.0", "features" = [
+lance = { "version" = "=0.21.1", "features" = [
     "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-io = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-index = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-linalg = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-table = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-testing = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-datafusion = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-encoding = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
+], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
 # Note that this one does not include pyarrow
 arrow = { version = "53.2", optional = false }
 arrow-array = "53.2"
@@ -50,7 +50,7 @@ Consider that we have a LanceDB table named `my_table`, whose string column `text`
 });

 await tbl
-  .search("puppy", queryType="fts")
+  .search("puppy", "fts")
   .select(["text"])
   .limit(10)
   .toArray();
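Note: the old snippet used Python-style keyword syntax (`queryType="fts"`), which JavaScript parses as an assignment expression rather than a named argument, so the query type must be passed positionally. The corrected call in context, as a sketch assuming `tbl` is an open table with an FTS index on `text`:

```ts
// Full-text search: the query type is the second positional argument.
const results = await tbl
  .search("puppy", "fts")
  .select(["text"])
  .limit(10)
  .toArray();
```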
@@ -133,6 +133,10 @@ lists the indices that LanceDb supports.

 ::: lancedb.index.IvfPq

+::: lancedb.index.HnswPq
+
+::: lancedb.index.HnswSq
+
 ::: lancedb.index.IvfFlat

 ## Querying (Asynchronous)
@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.14.1-beta.6</version>
+    <version>0.14.1-final.0</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

@@ -6,7 +6,7 @@

   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.14.1-beta.6</version>
+  <version>0.14.1-final.0</version>
   <packaging>pom</packaging>

   <name>LanceDB Parent</name>
node/package-lock.json (generated, 111 changes)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.14.1-beta.6",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"
@@ -52,14 +52,14 @@
       "uuid": "^9.0.0"
     },
     "optionalDependencies": {
-      "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.6",
-      "@lancedb/vectordb-darwin-x64": "0.14.1-beta.6",
-      "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.6",
-      "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.6",
-      "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.6",
-      "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.6",
-      "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.6",
-      "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.6"
+      "@lancedb/vectordb-darwin-arm64": "0.14.1",
+      "@lancedb/vectordb-darwin-x64": "0.14.1",
+      "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+      "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+      "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+      "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+      "@lancedb/vectordb-win32-arm64-msvc": "0.14.1",
+      "@lancedb/vectordb-win32-x64-msvc": "0.14.1"
     },
     "peerDependencies": {
       "@apache-arrow/ts": "^14.0.2",
@@ -329,6 +329,97 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
+    "node_modules/@lancedb/vectordb-darwin-arm64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.14.1.tgz",
+      "integrity": "sha512-6t7XHR7dBjDmAS/kz5wbe7LPhKW+WkFA16ZPyh0lmuxfnss4VvN3LE6qQBHjzYzB9U6Nu/4ktQ50xZGEPTnc5A==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-darwin-x64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.14.1.tgz",
+      "integrity": "sha512-8q6Kd6XnNPKN8wqj75pHVQ4KFl6z9BaI6lWDiEaCNcO3bjPZkcLFNosJq4raxZ9iUi50Yl0qFJ6qR0XFVTwnnw==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.14.1.tgz",
+      "integrity": "sha512-4djEMmeNb+p6nW/C4xb8wdMwnIbWfO8fYAwiplOxzxeOpPaUC9rhwUUDCbrJDCpMa8RP5ED4/jC6yT8epaDMDw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.14.1.tgz",
+      "integrity": "sha512-c33hSsp16pnC58plzx1OXuifp9Rachx/MshE/L/OReoutt74fFdrRJwUjE4UCAysyY5QdvTrNm9OhDjopQK2Bw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.14.1.tgz",
+      "integrity": "sha512-psu6cH9iLiSbUEZD1EWbOA4THGYSwJvS2XICO9yN7A6D41AP/ynYMRZNKWo1fpdi2Fjb0xNQwiNhQyqwbi5gzA==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.14.1.tgz",
+      "integrity": "sha512-Rg4VWW80HaTFmR7EvNSu+nfRQQM8beO/otBn/Nus5mj5zFw/7cacGRmiEYhDnk5iAn8nauV+Jsi9j2U+C2hp5w==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.14.1.tgz",
+      "integrity": "sha512-XbifasmMbQIt3V9P0AtQND6M3XFiIAc1ZIgmjzBjOmxwqw4sQUwHMyJGIGOzKFZTK3fPJIGRHId7jAzXuBgfQg==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "win32"
+      ]
+    },
     "node_modules/@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "description": " Serverless, low-latency vector database for AI applications",
   "private": false,
   "main": "dist/index.js",
@@ -92,13 +92,13 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-x64": "0.14.1-beta.6",
-    "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.6",
-    "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.6",
-    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.6"
+    "@lancedb/vectordb-darwin-x64": "0.14.1",
+    "@lancedb/vectordb-darwin-arm64": "0.14.1",
+    "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+    "@lancedb/vectordb-win32-x64-msvc": "0.14.1",
+    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1"
   }
 }
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.14.1-beta.6"
+version = "0.14.1"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
@@ -12,7 +12,10 @@ categories.workspace = true
 crate-type = ["cdylib"]

 [dependencies]
+async-trait.workspace = true
 arrow-ipc.workspace = true
+arrow-array.workspace = true
+arrow-schema.workspace = true
 env_logger.workspace = true
 futures.workspace = true
 lancedb = { path = "../rust/lancedb", features = ["remote"] }
@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";

 import {
   convertToTable,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   fromTableToBuffer,
   makeArrowTable,
   makeEmptyTable,
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       });
     });
   });
+
+  describe("converting record batches to buffers", function () {
+    it("can convert to buffered record batch and back again", async function () {
+      const records = [
+        { text: "dog", vector: [0.1, 0.2] },
+        { text: "cat", vector: [0.3, 0.4] },
+      ];
+      const table = await convertToTable(records);
+      const batch = table.batches[0];
+
+      const buffer = await fromRecordBatchToBuffer(batch);
+      const result = await fromBufferToRecordBatch(buffer);
+
+      expect(JSON.stringify(batch.toArray())).toEqual(
+        JSON.stringify(result?.toArray()),
+      );
+    });
+
+    it("converting from buffer returns null if buffer has no record batches", async function () {
+      const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
+      expect(result).toEqual(null);
+    });
+  });
 },
 );
nodejs/__test__/rerankers.test.ts (new file, 79 lines)
@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import * as tmp from "tmp";
+import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
+import { RRFReranker } from "../lancedb/rerankers";
+
+describe("rerankers", function () {
+  let tmpDir: tmp.DirResult;
+  let conn: Connection;
+  let table: Table;
+
+  beforeEach(async () => {
+    tmpDir = tmp.dirSync({ unsafeCleanup: true });
+    conn = await connect(tmpDir.name);
+    table = await conn.createTable("mytable", [
+      { vector: [0.1, 0.1], text: "dog" },
+      { vector: [0.2, 0.2], text: "cat" },
+    ]);
+    await table.createIndex("text", {
+      config: Index.fts(),
+      replace: true,
+    });
+  });
+
+  it("will query with the custom reranker", async function () {
+    const expectedResult = [
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ];
+    class MyCustomReranker {
+      async rerankHybrid(
+        _query: string,
+        _vecResults: RecordBatch,
+        _ftsResults: RecordBatch,
+      ): Promise<RecordBatch> {
+        // no reranker logic, just return some static data
+        const table = makeArrowTable(expectedResult);
+        return table.batches[0];
+      }
+    }
+
+    let result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(new MyCustomReranker())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
+    expect(result).toEqual([
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ]);
+  });
+
+  it("will query with RRFReranker", async function () {
+    // smoke test to see if the Rust wrapping Typescript is wired up correctly
+    const result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(await RRFReranker.create())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    expect(result).toHaveLength(2);
+  });
+});
@@ -27,7 +27,9 @@ import {
   List,
   Null,
   RecordBatch,
+  RecordBatchFileReader,
   RecordBatchFileWriter,
+  RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
   }
 }

+/**
+ * Read a single record batch from a buffer.
+ *
+ * Returns null if the buffer does not contain a record batch
+ */
+export async function fromBufferToRecordBatch(
+  data: Buffer,
+): Promise<RecordBatch | null> {
+  const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
+    .value;
+  const recordBatch = iter?.next().value;
+  return recordBatch || null;
+}
+
+/**
+ * Create a buffer containing a single record batch
+ */
+export async function fromRecordBatchToBuffer(
+  batch: RecordBatch,
+): Promise<Buffer> {
+  const writer = new RecordBatchFileWriter().writeAll([batch]);
+  return Buffer.from(await writer.toUint8Array());
+}
+
 /**
  * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
  *
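The two helpers added above round-trip a single RecordBatch through the Arrow IPC file format; `fromBufferToRecordBatch` resolves to null when the buffer holds no record batch. A minimal sketch mirroring the added arrow.test.ts cases (the relative import path is an assumption for code sitting next to nodejs/lancedb/arrow.ts):

```ts
import {
  convertToTable,
  fromBufferToRecordBatch,
  fromRecordBatchToBuffer,
} from "./arrow";

// Serialize one batch to an IPC file buffer, then read it back.
const table = await convertToTable([{ text: "dog", vector: [0.1, 0.2] }]);
const batch = table.batches[0];
const buffer = await fromRecordBatchToBuffer(batch);
const restored = await fromBufferToRecordBatch(buffer); // RecordBatch | null

// A buffer that is not an Arrow IPC file yields null rather than a batch.
const none = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02]));
```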
@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
 export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";

 export * as embedding from "./embedding";
+export * as rerankers from "./rerankers";

 /**
  * Connect to a LanceDB instance at the given URI.
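With this export, the reranker types become reachable from the package root; a sketch assuming the published `@lancedb/lancedb` entry point:

```ts
import * as lancedb from "@lancedb/lancedb";

const db = await lancedb.connect("data/sample-lancedb");
const rrf = await lancedb.rerankers.RRFReranker.create(); // k defaults to 60
```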
@@ -16,6 +16,8 @@ import {
   Table as ArrowTable,
   type IntoVector,
   RecordBatch,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   tableFromIPC,
 } from "./arrow";
 import { type IvfPqOptions } from "./indices";
@@ -25,6 +27,7 @@ import {
   Table as NativeTable,
   VectorQuery as NativeVectorQuery,
 } from "./native";
+import { Reranker } from "./rerankers";
 export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
   private promisedInner?: Promise<NativeBatchIterator>;
   private inner?: NativeBatchIterator;
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
       return this;
     }
   }
+
+  rerank(reranker: Reranker): VectorQuery {
+    super.doCall((inner) =>
+      inner.rerank({
+        rerankHybrid: async (_, args) => {
+          const vecResults = await fromBufferToRecordBatch(args.vecResults);
+          const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
+          const result = await reranker.rerankHybrid(
+            args.query,
+            vecResults as RecordBatch,
+            ftsResults as RecordBatch,
+          );
+
+          const buffer = fromRecordBatchToBuffer(result);
+          return buffer;
+        },
+      }),
+    );
+
+    return this;
+  }
 }

 /** A builder for LanceDB queries. */
nodejs/lancedb/rerankers/index.ts (new file, 17 lines)

@@ -0,0 +1,17 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+
+export * from "./rrf";
+
+// Interface for a reranker. A reranker is used to rerank the results from a
+// vector and FTS search. This is useful for combining the results from both
+// search methods.
+export interface Reranker {
+  rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch>;
+}
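Any object satisfying this interface can be handed to `VectorQuery.rerank`. A minimal custom implementation, adapted from the added rerankers test (a pass-through strategy here; a real reranker would fuse and re-score both result sets):

```ts
import { RecordBatch } from "apache-arrow";
import { Reranker } from "./rerankers"; // import path assumed, as within nodejs/lancedb

// Trivial reranker: ignore the FTS results and keep the vector-search order.
class PassThroughReranker implements Reranker {
  async rerankHybrid(
    _query: string,
    vecResults: RecordBatch,
    _ftsResults: RecordBatch,
  ): Promise<RecordBatch> {
    return vecResults;
  }
}
```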
nodejs/lancedb/rerankers/rrf.ts (new file, 40 lines)

@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
+import { RrfReranker as NativeRRFReranker } from "../native";
+
+/**
+ * Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
+ *
+ * Internally this uses the Rust implementation
+ */
+export class RRFReranker {
+  private inner: NativeRRFReranker;
+
+  constructor(inner: NativeRRFReranker) {
+    this.inner = inner;
+  }
+
+  public static async create(k: number = 60) {
+    return new RRFReranker(
+      await NativeRRFReranker.tryNew(new Float32Array([k])),
+    );
+  }
+
+  async rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch> {
+    const buffer = await this.inner.rerankHybrid(
+      query,
+      await fromRecordBatchToBuffer(vecResults),
+      await fromRecordBatchToBuffer(ftsResults),
+    );
+    const recordBatch = await fromBufferToRecordBatch(buffer);
+
+    return recordBatch as RecordBatch;
+  }
+}
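End to end, the RRF path is exercised by the new test suite; a condensed sketch assuming a `table` whose `text` column carries an FTS index, as set up in rerankers.test.ts:

```ts
import { RRFReranker } from "../lancedb/rerankers"; // import path as in the tests

// Hybrid query: vector search fused with full-text results via RRF.
const results = await table
  .query()
  .nearestTo([0.1, 0.1])
  .fullTextSearch("dog")
  .rerank(await RRFReranker.create()) // k defaults to 60
  .select(["text"])
  .limit(5)
  .toArray();
```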
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": [
     "win32"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
nodejs/package-lock.json (generated, 4 changes)
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.14.1-beta.6",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"
@@ -11,7 +11,7 @@
     "ann"
   ],
   "private": false,
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "main": "dist/index.js",
   "exports": {
     ".": "./dist/index.js",
@@ -24,6 +24,7 @@ mod iterator;
 pub mod merge;
 mod query;
 pub mod remote;
+mod rerankers;
 mod table;
 mod util;

@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::sync::Arc;
+
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::ExecutableQuery;
 use lancedb::query::Query as LanceDbQuery;
@@ -25,6 +27,8 @@ use napi_derive::napi;
 use crate::error::convert_error;
 use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
+use crate::rerankers::Reranker;
+use crate::rerankers::RerankerCallbacks;
 use crate::util::parse_distance_type;

 #[napi]
@@ -218,6 +222,14 @@ impl VectorQuery {
         self.inner = self.inner.clone().with_row_id();
     }

+    #[napi]
+    pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
+        self.inner = self
+            .inner
+            .clone()
+            .rerank(Arc::new(Reranker::new(callbacks)));
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
nodejs/src/rerankers.rs (new file, 147 lines)

@@ -0,0 +1,147 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow_array::RecordBatch;
+use async_trait::async_trait;
+use napi::{
+    bindgen_prelude::*,
+    threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
+};
+use napi_derive::napi;
+
+use lancedb::ipc::batches_to_ipc_file;
+use lancedb::rerankers::Reranker as LanceDBReranker;
+use lancedb::{error::Error, ipc::ipc_file_to_batches};
+
+use crate::error::NapiErrorExt;
+
+/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
+/// This contains references to the callbacks that can be used to invoke the
+/// reranking methods on the NodeJS implementation and handles serializing the
+/// record batches to Arrow IPC buffers.
+#[napi]
+pub struct Reranker {
+    /// callback to the Javascript which will call the rerankHybrid method of
+    /// some Reranker implementation
+    rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
+}
+
+#[napi]
+impl Reranker {
+    #[napi]
+    pub fn new(callbacks: RerankerCallbacks) -> Self {
+        let rerank_hybrid = callbacks
+            .rerank_hybrid
+            .create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
+            .unwrap();
+
+        Self { rerank_hybrid }
+    }
+}
+
+#[async_trait]
+impl lancedb::rerankers::Reranker for Reranker {
+    async fn rerank_hybrid(
+        &self,
+        query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> lancedb::error::Result<RecordBatch> {
+        let callback_args = RerankHybridCallbackArgs {
+            query: query.to_string(),
+            vec_results: batches_to_ipc_file(&[vector_results])?,
+            fts_results: batches_to_ipc_file(&[fts_results])?,
+        };
+        let promised_buffer: Promise<Buffer> = self
+            .rerank_hybrid
+            .call_async(Ok(callback_args))
+            .await
+            .map_err(|e| Error::Runtime {
+                message: format!("napi error status={}, reason={}", e.status, e.reason),
+            })?;
+        let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
+            message: format!("napi error status={}, reason={}", e.status, e.reason),
+        })?;
+        let mut reader = ipc_file_to_batches(buffer.to_vec())?;
+        let result = reader.next().ok_or(Error::Runtime {
+            message: "reranker result deserialization failed".to_string(),
+        })??;
+
+        return Ok(result);
+    }
+}
+
+impl std::fmt::Debug for Reranker {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("NodeJSRerankerWrapper")
+    }
+}
+
+#[napi(object)]
+pub struct RerankerCallbacks {
+    pub rerank_hybrid: JsFunction,
+}
+
+#[napi(object)]
+pub struct RerankHybridCallbackArgs {
+    pub query: String,
+    pub vec_results: Vec<u8>,
+    pub fts_results: Vec<u8>,
+}
+
+fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
+    let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
+    reader
+        .next()
+        .ok_or(Error::InvalidInput {
+            message: "expected buffer containing record batch".to_string(),
+        })
+        .default_error()?
+        .map_err(Error::from)
+        .default_error()
+}
+
+/// Wrapper around rust RRFReranker
+#[napi]
+pub struct RRFReranker {
+    inner: lancedb::rerankers::rrf::RRFReranker,
+}
+
+#[napi]
+impl RRFReranker {
+    #[napi]
+    pub async fn try_new(k: &[f32]) -> Result<Self> {
+        let k = k
+            .first()
+            .copied()
+            .ok_or(Error::InvalidInput {
+                message: "must supply RRF Reranker constructor arg 'k'".to_string(),
+            })
+            .default_error()?;
+
+        Ok(Self {
+            inner: lancedb::rerankers::rrf::RRFReranker::new(k),
+        })
+    }
+
+    #[napi]
+    pub async fn rerank_hybrid(
+        &self,
+        query: String,
+        vec_results: Buffer,
+        fts_results: Buffer,
+    ) -> Result<Buffer> {
+        let vec_results = buffer_to_record_batch(vec_results)?;
+        let fts_results = buffer_to_record_batch(fts_results)?;
+
+        let result = self
+            .inner
+            .rerank_hybrid(&query, vec_results, fts_results)
+            .await
+            .unwrap();
+
+        let result_buff = batches_to_ipc_file(&[result]).default_error()?;
+
+        Ok(Buffer::from(result_buff.as_ref()))
+    }
+}
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.17.1"
+current_version = "0.17.2-beta.2"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.17.1"
+version = "0.17.2-beta.2"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
@@ -1,9 +1,10 @@
 [project]
 name = "lancedb"
 # version in Cargo.toml
+dynamic = ["version"]
 dependencies = [
     "deprecation",
-    "pylance==0.21.0b5",
+    "pylance==0.21.1b1",
     "tqdm>=4.27.0",
     "pydantic>=1.10",
     "packaging",
@@ -52,8 +53,9 @@ tests = [
     "pytz",
     "polars>=0.19, <=1.3.0",
     "tantivy",
+    "pyarrow-stubs"
 ]
-dev = ["ruff", "pre-commit"]
+dev = ["ruff", "pre-commit", "pyright"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -93,3 +95,7 @@ markers = [
     "asyncio",
     "s3_test",
 ]
+
+[tool.pyright]
+include = ["python/lancedb/table.py"]
+pythonVersion = "3.12"
@@ -1,7 +1,9 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Any, Union, Literal

 import pyarrow as pa

+from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+
 class Connection(object):
     uri: str
     async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
 class Table:
     def name(self) -> str: ...
     def __repr__(self) -> str: ...
+    def is_open(self) -> bool: ...
+    def close(self) -> None: ...
     async def schema(self) -> pa.Schema: ...
-    async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
+    async def add(
+        self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
+    ) -> None: ...
     async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
     async def count_rows(self, filter: Optional[str]) -> int: ...
-    async def create_index(self, column: str, config, replace: Optional[bool]): ...
+    async def create_index(
+        self,
+        column: str,
+        index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
+        replace: Optional[bool],
+    ): ...
+    async def list_versions(self) -> List[Dict[str, Any]]: ...
     async def version(self) -> int: ...
-    async def checkout(self, version): ...
+    async def checkout(self, version: int): ...
     async def checkout_latest(self): ...
     async def restore(self): ...
-    async def list_indices(self) -> List[IndexConfig]: ...
+    async def list_indices(self) -> list[IndexConfig]: ...
+    async def delete(self, filter: str): ...
+    async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
+    async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
+    async def optimize(
+        self,
+        *,
+        cleanup_since_ms: Optional[int] = None,
+        delete_unverified: Optional[bool] = None,
+    ) -> OptimizeStats: ...
     def query(self) -> Query: ...
     def vector_search(self) -> VectorQuery: ...

@@ -603,7 +603,7 @@ class AsyncConnection(object):
         fill_value: Optional[float] = None,
         storage_options: Optional[Dict[str, str]] = None,
         *,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         data_storage_version: Optional[str] = None,
         use_legacy_format: Optional[bool] = None,
         enable_v2_manifest_paths: Optional[bool] = None,
@@ -1,20 +1,10 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors

 """Full text search index using tantivy-py"""

 import os
-from typing import List, Tuple
+from typing import List, Tuple, Optional

 import pyarrow as pa

@@ -31,7 +21,7 @@ from .table import LanceTable
|
|||||||
def create_index(
|
def create_index(
|
||||||
index_path: str,
|
index_path: str,
|
||||||
text_fields: List[str],
|
text_fields: List[str],
|
||||||
ordering_fields: List[str] = None,
|
ordering_fields: Optional[List[str]] = None,
|
||||||
tokenizer_name: str = "default",
|
tokenizer_name: str = "default",
|
||||||
) -> tantivy.Index:
|
) -> tantivy.Index:
|
||||||
"""
|
"""
|
||||||
@@ -75,8 +65,8 @@ def populate_index(
|
|||||||
index: tantivy.Index,
|
index: tantivy.Index,
|
||||||
table: LanceTable,
|
table: LanceTable,
|
||||||
fields: List[str],
|
fields: List[str],
|
||||||
writer_heap_size: int = 1024 * 1024 * 1024,
|
writer_heap_size: Optional[int] = None,
|
||||||
ordering_fields: List[str] = None,
|
ordering_fields: Optional[List[str]] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Populate an index with data from a LanceTable
|
Populate an index with data from a LanceTable
|
||||||
@@ -99,6 +89,7 @@ def populate_index(
|
|||||||
"""
|
"""
|
||||||
if ordering_fields is None:
|
if ordering_fields is None:
|
||||||
ordering_fields = []
|
ordering_fields = []
|
||||||
|
writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
|
||||||
# first check the fields exist and are string or large string type
|
# first check the fields exist and are string or large string type
|
||||||
nested = []
|
nested = []
|
||||||
|
|
||||||
|
|||||||
@@ -568,4 +568,14 @@ class IvfPq:
     sample_rate: int = 256


-__all__ = ["BTree", "IvfFlat", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]
+__all__ = [
+    "BTree",
+    "IvfPq",
+    "IvfFlat",
+    "HnswPq",
+    "HnswSq",
+    "IndexConfig",
+    "FTS",
+    "Bitmap",
+    "LabelList",
+]
@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
     # e.g. `{"nprobes": "10", "refine_factor": "10"}`
     nprobes: int = 10

+    lower_bound: Optional[float] = None
+    upper_bound: Optional[float] = None
+
     # Refine factor.
     refine_factor: Optional[int] = None

@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._query = query
         self._metric = "L2"
         self._nprobes = 20
+        self._lower_bound = None
+        self._upper_bound = None
         self._refine_factor = None
         self._vector_column = vector_column
         self._prefilter = False
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceVectorQueryBuilder:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceVectorQueryBuilder:
         """Set the number of candidates to consider during search.

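For reference, the new builder method chains like any other query option. A minimal sketch, mirroring the tests later in this change (the query vector and bounds are hypothetical):

results = (
    table.search([0.0, 0.0])
    .distance_range(lower_bound=0.0, upper_bound=0.5)
    .limit(10)
    .to_arrow()
)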
@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             metric=self._metric,
             columns=self._columns,
             nprobes=self._nprobes,
+            lower_bound=self._lower_bound,
+            upper_bound=self._upper_bound,
             refine_factor=self._refine_factor,
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceHybridQueryBuilder:
+        """
+        Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceHybridQueryBuilder
+            The LanceHybridQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceHybridQueryBuilder:
         """
         Set the number of candidates to consider during search.
@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
         self._inner.nprobes(nprobes)
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> AsyncVectorQuery:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        AsyncVectorQuery
+            The AsyncVectorQuery object.
+        """
+        self._inner.distance_range(lower_bound, upper_bound)
+        return self
+
     def ef(self, ef: int) -> AsyncVectorQuery:
         """
         Set the number of candidates to consider during search

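The async builder mirrors the sync one. A minimal sketch of the same query against an AsyncTable (query vector and bounds hypothetical), matching the async tests added later in this change:

results = (
    await tbl.query()
    .nearest_to([0.0, 0.0])
    .distance_range(lower_bound=0.0, upper_bound=0.5)
    .to_arrow()
)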
@@ -1,15 +1,5 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors

 from datetime import timedelta
 import logging
@@ -19,7 +9,7 @@ import warnings

 from lancedb._lancedb import IndexConfig
 from lancedb.embeddings.base import EmbeddingFunctionConfig
-from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
+from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
 from lancedb.remote.db import LOOP
 import pyarrow as pa

@@ -91,7 +81,7 @@ class RemoteTable(Table):
         """to_pandas() is not yet supported on LanceDB cloud."""
         return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")

-    def checkout(self, version):
+    def checkout(self, version: int):
         return LOOP.run(self._table.checkout(version))

     def checkout_latest(self):
@@ -235,10 +225,12 @@ class RemoteTable(Table):
             config = HnswPq(distance_type=metric)
         elif index_type == "IVF_HNSW_SQ":
             config = HnswSq(distance_type=metric)
+        elif index_type == "IVF_FLAT":
+            config = IvfFlat(distance_type=metric)
         else:
             raise ValueError(
                 f"Unknown vector index type: {index_type}. Valid options are"
-                " 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
+                " 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
            )

         LOOP.run(self._table.create_index(vector_column_name, config=config))
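With this change a remote table should accept the new index type string. A sketch under the parameter names used by the surrounding code (the column name is hypothetical):

table.create_index(
    metric="cosine",
    vector_column_name="vector",
    index_type="IVF_FLAT",
)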
@@ -61,11 +61,12 @@ from .index import lang_mapping


 if TYPE_CHECKING:
-    import PIL
-    from lance.dataset import CleanupStats, ReaderLike
-    from ._lancedb import Table as LanceDBTable, OptimizeStats
+    from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
     from .db import LanceDBConnection
     from .index import IndexConfig
+    from lance.dataset import CleanupStats, ReaderLike
+    import pandas
+    import PIL

 pd = safe_import_pandas()
 pl = safe_import_polars()
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
     )
     if not embedding_functions:
         return schema
-    columns = set(columns)
     return pa.schema([field for field in schema if field.name in columns])


@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
             return pa.Table.from_batches(data, schema=schema)
         else:
             return pa.Table.from_pylist(data, schema=schema)
-    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):  # type: ignore
         raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
         table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
         # Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
     metadata: Optional[dict] = None,  # embedding metadata
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
-):
+) -> Tuple[pa.Table, pa.Schema]:
     data = _coerce_to_table(data, schema)

     if metadata:
@@ -178,13 +178,17 @@ def _sanitize_data(


 def sanitize_create_table(
-    data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
+    data,
+    schema: Union[pa.Schema, LanceModel],
+    metadata=None,
+    on_bad_vectors: str = "error",
+    fill_value: float = 0.0,
 ):
     if inspect.isclass(schema) and issubclass(schema, LanceModel):
         # convert LanceModel to pyarrow schema
         # note that it's possible this contains
         # embedding function metadata already
-        schema = schema.to_arrow_schema()
+        schema: pa.Schema = schema.to_arrow_schema()

     if data is not None:
         if metadata is None and schema is not None:
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
     return data


-def _generator_to_data_and_schema(
-    data: Iterable,
-) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
-    def _with_first_generator(first, data):
-        yield first
-        yield from data
-
-    first = next(data, None)
-    schema = None
-    if isinstance(first, pa.RecordBatch):
-        schema = first.schema
-        data = _with_first_generator(first, data)
-    elif isinstance(first, pa.Table):
-        schema = first.schema
-        data = _with_first_generator(first.to_batches(), data)
-    return data, schema
-
-
-def _to_record_batch_generator(
-    data: Iterable,
-    schema,
-    metadata,
-    on_bad_vectors,
-    fill_value,
-):
-    for batch in data:
-        # always convert to table because we need to sanitize the data
-        # and do things like add the vector column etc
-        if isinstance(batch, pa.RecordBatch):
-            batch = pa.Table.from_batches([batch])
-        batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
-        for b in batch.to_batches():
-            yield b
-
-
 def _table_path(base: str, table_name: str) -> str:
     """
     Get a table path that can be used in PyArrow FS.
@@ -404,7 +373,7 @@ class Table(ABC):
         """
         raise NotImplementedError

-    def to_pandas(self) -> "pd.DataFrame":
+    def to_pandas(self) -> "pandas.DataFrame":
         """Return the table as a pandas DataFrame.

         Returns
@@ -537,8 +506,8 @@ class Table(ABC):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
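Note that `ordering_field_names` moved behind the `*`, so it must now be passed by keyword; a positional second argument now raises a TypeError. A sketch of the updated call (column names hypothetical):

table.create_fts_index("text", ordering_field_names=["id"], replace=True)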
@@ -790,8 +759,7 @@ class Table(ABC):
     @abstractmethod
     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
-    ) -> pa.RecordBatchReader:
-        pass
+    ) -> pa.RecordBatchReader: ...

     @abstractmethod
     def _do_merge(
@@ -800,8 +768,7 @@ class Table(ABC):
         new_data: DATA,
         on_bad_vectors: str,
         fill_value: float,
-    ):
-        pass
+    ): ...

     @abstractmethod
     def delete(self, where: str):
@@ -1092,7 +1059,7 @@ class Table(ABC):
         """

     @abstractmethod
-    def checkout(self):
+    def checkout(self, version: int):
         """
         Checks out a specific version of the Table

@@ -1121,7 +1088,7 @@ class Table(ABC):
         """

     @abstractmethod
-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""

     @cached_property
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
         A PyArrow schema object."""
         return LOOP.run(self._table.schema())

-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""
         return LOOP.run(self._table.list_versions())

@@ -1297,7 +1264,7 @@ class LanceTable(Table):
         """
         LOOP.run(self._table.checkout_latest())

-    def restore(self, version: int = None):
+    def restore(self, version: Optional[int] = None):
         """Restore a version of the table. This is an in-place operation.

         This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@ class LanceTable(Table):
     def count_rows(self, filter: Optional[str] = None) -> int:
         return LOOP.run(self._table.count_rows(filter))

-    def __len__(self):
+    def __len__(self) -> int:
         return self.count_rows()

     def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@ class LanceTable(Table):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@ class LanceTable(Table):
             writer_heap_size=writer_heap_size,
         )

+    @staticmethod
     def infer_tokenizer_configs(tokenizer_name: str) -> dict:
         if tokenizer_name == "default":
             return {
@@ -1759,7 +1727,7 @@ class LanceTable(Table):
         )

     @overload
-    def search(
+    def search(  # type: ignore
         self,
         query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
         vector_column_name: Optional[str] = None,
@@ -1895,11 +1863,11 @@ class LanceTable(Table):
         name: str,
         data: Optional[DATA] = None,
         schema: Optional[pa.Schema] = None,
-        mode: Literal["create", "overwrite", "append"] = "create",
+        mode: Literal["create", "overwrite"] = "create",
         exist_ok: bool = False,
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
         storage_options: Optional[Dict[str, str]] = None,
         data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@ class LanceTable(Table):
             older_than, delete_unverified=delete_unverified
         )

-    def compact_files(self, *args, **kwargs):
+    def compact_files(self, *args, **kwargs) -> CompactionStats:
         """
         Run the compaction process on the table.

@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.
         if batch_table.schema != schema:
             try:
                 batch_table = batch_table.cast(schema)
-            except pa.lib.ArrowInvalid:
+            except pa.lib.ArrowInvalid:  # type: ignore
                 raise ValueError(
                     f"Input iterator yielded a batch with schema that "
                     f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        data, _ = _sanitize_data(
+        table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
             data,
             schema,
             metadata=schema.metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        if isinstance(data, pa.Table):
-            data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
-        await self._inner.add(data, mode)
+        tbl, schema = table_and_schema
+        if isinstance(tbl, pa.Table):
+            data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
+        await self._inner.add(data, mode or "append")

     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """
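After this fix, `AsyncTable.add` sanitizes the input and falls back to append mode when none is given. A hedged sketch, assuming a hypothetical table whose schema has "id" and "vector" columns:

import pandas as pd

new_rows = pd.DataFrame({"id": [101], "vector": [[0.3, 0.9]]})
await tbl.add(new_rows)                    # mode defaults to "append"
await tbl.add(new_rows, mode="overwrite")  # or replace the existing rows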
@@ -2817,6 +2786,7 @@ class AsyncTable:
             async_query.nearest_to(query.vector)
             .distance_type(query.metric)
             .nprobes(query.nprobes)
+            .distance_range(query.lower_bound, query.upper_bound)
         )
         if query.refine_factor:
             async_query = async_query.refine_factor(query.refine_factor)
@@ -2977,7 +2947,7 @@ class AsyncTable:

         return await self._inner.update(updates_sql, where)

-    async def add_columns(self, transforms: Dict[str, str]):
+    async def add_columns(self, transforms: dict[str, str]):
         """
         Add new columns with defined values.

@@ -2990,7 +2960,7 @@ class AsyncTable:
         """
         await self._inner.add_columns(list(transforms.items()))

-    async def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
+    async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
         """
         Alter column names and nullability.

@@ -3049,7 +3019,7 @@ class AsyncTable:

         return versions

-    async def checkout(self, version):
+    async def checkout(self, version: int):
         """
         Checks out a specific version of the Table

@@ -3148,9 +3118,12 @@ class AsyncTable:
         you have added or modified 100,000 or more records or run more than 20 data
         modification operations.
         """
+        cleanup_since_ms: Optional[int] = None
         if cleanup_older_than is not None:
-            cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
-        return await self._inner.optimize(cleanup_older_than, delete_unverified)
+            cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
+        return await self._inner.optimize(
+            cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
+        )

     async def list_indices(self) -> Iterable[IndexConfig]:
         """
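The fix above forwards the cleanup window as `cleanup_since_ms` instead of mutating the `timedelta` argument in place; the public call is unchanged. A minimal sketch:

from datetime import timedelta

stats = await tbl.optimize(cleanup_older_than=timedelta(days=7))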
@@ -167,8 +167,24 @@ def test_search_index(tmp_path, table):
 @pytest.mark.parametrize("use_tantivy", [True, False])
 def test_search_fts(table, use_tantivy):
     table.create_fts_index("text", use_tantivy=use_tantivy)
-    results = table.search("puppy").limit(5).to_list()
+    results = table.search("puppy").select(["id", "text"]).limit(5).to_list()
     assert len(results) == 5
+    assert len(results[0]) == 3  # id, text, _score
+
+
+@pytest.mark.asyncio
+async def test_fts_select_async(async_table):
+    tbl = await async_table
+    await tbl.create_index("text", config=FTS())
+    results = (
+        await tbl.query()
+        .nearest_to_text("puppy")
+        .select(["id", "text"])
+        .limit(5)
+        .to_list()
+    )
+    assert len(results) == 5
+    assert len(results[0]) == 3  # id, text, _score


 def test_search_fts_phrase_query(table):

@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
     assert rs["_rowid"].to_pylist() == [0, 1]


+def test_distance_range(table: lancedb.table.Table):
+    q = [0, 0]
+    rs = table.search(q).to_arrow()
+    dists = rs["_distance"].to_pylist()
+    min_dist = dists[0]
+    max_dist = dists[-1]
+
+    res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
+    assert len(res) == 0
+
+    res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [max_dist]
+
+    res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [min_dist]
+
+    res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
+    assert len(res) == 2
+    assert res["_distance"].to_pylist() == [min_dist, max_dist]
+
+
+@pytest.mark.asyncio
+async def test_distance_range_async(table_async: AsyncTable):
+    q = [0, 0]
+    rs = await table_async.query().nearest_to(q).to_arrow()
+    dists = rs["_distance"].to_pylist()
+    min_dist = dists[0]
+    max_dist = dists[-1]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(upper_bound=min_dist)
+        .to_arrow()
+    )
+    assert len(res) == 0
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(lower_bound=max_dist)
+        .to_arrow()
+    )
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [max_dist]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(upper_bound=max_dist)
+        .to_arrow()
+    )
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [min_dist]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(lower_bound=min_dist)
+        .to_arrow()
+    )
+    assert len(res) == 2
+    assert res["_distance"].to_pylist() == [min_dist, max_dist]
+
+
 def test_vector_query_with_no_limit(table):
     with pytest.raises(ValueError):
         LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(

@@ -306,6 +306,8 @@ def test_query_sync_minimal():
         "k": 10,
         "prefilter": False,
         "refine_factor": None,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "vector": [1.0, 2.0, 3.0],
         "nprobes": 20,
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
         "refine_factor": 10,
         "vector": [1.0, 2.0, 3.0],
         "nprobes": 5,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "filter": "id > 0",
         "columns": ["id", "name"],
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
         "refine_factor": None,
         "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         "nprobes": 20,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "with_row_id": True,
         "version": None,
@@ -152,6 +152,10 @@ impl FTSQuery {
         self.inner = self.inner.clone().select(Select::dynamic(&columns));
     }

+    pub fn select_columns(&mut self, columns: Vec<String>) {
+        self.inner = self.inner.clone().select(Select::columns(&columns));
+    }
+
     pub fn limit(&mut self, limit: u32) {
         self.inner = self.inner.clone().limit(limit as usize);
     }
@@ -280,6 +284,11 @@ impl VectorQuery {
         self.inner = self.inner.clone().nprobes(nprobe as usize);
     }

+    #[pyo3(signature = (lower_bound=None, upper_bound=None))]
+    pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
+        self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
+    }
+
     pub fn ef(&mut self, ef: u32) {
         self.inner = self.inner.clone().ef(ef as usize);
     }
@@ -341,6 +350,11 @@ impl HybridQuery {
         self.inner_fts.select(columns);
     }

+    pub fn select_columns(&mut self, columns: Vec<String>) {
+        self.inner_vec.select_columns(columns.clone());
+        self.inner_fts.select_columns(columns);
+    }
+
     pub fn limit(&mut self, limit: u32) {
         self.inner_vec.limit(limit);
         self.inner_fts.limit(limit);

@@ -97,10 +97,12 @@ impl Table {
         self.name.clone()
     }

+    /// Returns True if the table is open, False if it is closed.
     pub fn is_open(&self) -> bool {
         self.inner.is_some()
     }

+    /// Closes the table, releasing any resources associated with it.
     pub fn close(&mut self) {
         self.inner.take();
     }
@@ -301,6 +303,7 @@ impl Table {
         Query::new(self.inner_ref().unwrap().query())
     }

+    /// Optimize the on-disk data by compacting and pruning old data, for better performance.
     #[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None))]
     pub fn optimize(
         self_: PyRef<'_, Self>,
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-node"
-version = "0.14.1-beta.6"
+version = "0.14.1"
 description = "Serverless, low-latency vector database for AI applications"
 license.workspace = true
 edition.workspace = true

@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.14.1-beta.6"
+version = "0.14.1"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true

@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
 pub mod query;
 #[cfg(feature = "remote")]
 pub mod remote;
+pub mod rerankers;
 pub mod table;
 pub mod utils;

@@ -15,19 +15,31 @@
 use std::future::Future;
 use std::sync::Arc;

+use arrow::compute::concat_batches;
 use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
 use arrow_schema::DataType;
 use datafusion_physical_plan::ExecutionPlan;
+use futures::{stream, try_join, FutureExt, TryStreamExt};
 use half::f16;
-use lance::dataset::scanner::DatasetRecordBatchStream;
+use lance::{
+    arrow::RecordBatchExt,
+    dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
+};
 use lance_datafusion::exec::execute_plan;
+use lance_index::scalar::inverted::SCORE_COL;
 use lance_index::scalar::FullTextSearchQuery;
+use lance_index::vector::DIST_COL;
+use lance_io::stream::RecordBatchStreamAdapter;

 use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
+use crate::rerankers::rrf::RRFReranker;
+use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
 use crate::table::TableInternal;
 use crate::DistanceType;

+mod hybrid;
+
 pub(crate) const DEFAULT_TOP_K: usize = 10;

 /// Which columns should be retrieved from the database
@@ -435,6 +447,16 @@ pub trait QueryBase {

     /// Return the `_rowid` meta column from the Table.
     fn with_row_id(self) -> Self;
+
+    /// Rerank the results using the specified reranker.
+    ///
+    /// This is currently only supported for Hybrid Search.
+    fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;
+
+    /// The method to normalize the scores. Can be "Rank" or "Score". If "Rank",
+    /// the scores are converted to ranks and then normalized. If "Score", the
+    /// scores are normalized directly.
+    fn norm(self, norm: NormalizeMethod) -> Self;
 }

 pub trait HasQuery {
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
         self.mut_query().with_row_id = true;
         self
     }
+
+    fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
+        self.mut_query().reranker = Some(reranker);
+        self
+    }
+
+    fn norm(mut self, norm: NormalizeMethod) -> Self {
+        self.mut_query().norm = Some(norm);
+        self
+    }
 }

 /// Options for controlling the execution of a query
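These Rust trait methods mirror the reranking knobs the Python hybrid builder exposes. A hedged sketch of the equivalent Python call, assuming a table with both an FTS and a vector index and that the hybrid builder's rerank() accepts a reranker instance by keyword:

from lancedb.rerankers import RRFReranker

results = (
    table.search(query_type="hybrid")
    .vector([0.0, 0.0])
    .text("puppy")
    .rerank(reranker=RRFReranker())
    .limit(5)
    .to_pandas()
)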
@@ -600,6 +632,13 @@ pub struct Query {

     /// If set to false, the filter will be applied after the vector search.
     pub(crate) prefilter: bool,
+
+    /// Implementation of reranker that can be used to reorder or combine query
+    /// results, especially if using hybrid search
+    pub(crate) reranker: Option<Arc<dyn Reranker>>,
+
+    /// Configure how query results are normalized when doing hybrid search
+    pub(crate) norm: Option<NormalizeMethod>,
 }

 impl Query {
@@ -614,6 +653,8 @@ impl Query {
             fast_search: false,
             with_row_id: false,
             prefilter: true,
+            reranker: None,
+            norm: None,
         }
     }

@@ -714,6 +755,10 @@ pub struct VectorQuery {
     // IVF PQ - ANN search.
     pub(crate) query_vector: Vec<Arc<dyn Array>>,
     pub(crate) nprobes: usize,
+    // The lower bound (inclusive) of the distance to search for.
+    pub(crate) lower_bound: Option<f32>,
+    // The upper bound (exclusive) of the distance to search for.
+    pub(crate) upper_bound: Option<f32>,
     // The number of candidates to return during the refine step for HNSW,
     // defaults to 1.5 * limit.
     pub(crate) ef: Option<usize>,
@@ -730,6 +775,8 @@ impl VectorQuery {
             column: None,
             query_vector: Vec::new(),
             nprobes: 20,
+            lower_bound: None,
+            upper_bound: None,
             ef: None,
             refine_factor: None,
             distance_type: None,
@@ -790,6 +837,14 @@ impl VectorQuery {
         self
     }

+    /// Set the distance range for vector search,
+    /// only rows with distances in the range [lower_bound, upper_bound) will be returned
+    pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
+        self.lower_bound = lower_bound;
+        self.upper_bound = upper_bound;
+        self
+    }
+
     /// Set the number of candidates to return during the refine step for HNSW
     ///
     /// This argument is only used when the vector column has an HNSW index.
@@ -862,6 +917,65 @@ impl VectorQuery {
         self.use_index = false;
         self
     }

+    pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
+        // clone query and specify we want to include row IDs, which can be needed for reranking
+        let fts_query = self.base.clone().with_row_id();
+        let mut vector_query = self.clone().with_row_id();
+
+        vector_query.base.full_text_search = None;
+        let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
+
+        let (fts_results, vec_results) = try_join!(
+            fts_results.try_collect::<Vec<_>>(),
+            vec_results.try_collect::<Vec<_>>()
+        )?;
+
+        // try to get the schema to use when combining batches.
+        // if either side is empty, its schema is derived from the other (see hybrid::query_schemas)
+        let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);
+
+        // concatenate all the batches together
+        let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
+        let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
+
+        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
+            vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
+            fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
+        }
+
+        vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
+        fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
+
+        let reranker = self
+            .base
+            .reranker
+            .clone()
+            .unwrap_or(Arc::new(RRFReranker::default()));
+
+        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
+            message: "there should be an FTS search".to_string(),
+        })?;
+
+        let mut results = reranker
+            .rerank_hybrid(&fts_query.query, vec_results, fts_results)
+            .await?;
+
+        check_reranker_result(&results)?;
+
+        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
+        if results.num_rows() > limit {
+            results = results.slice(0, limit);
+        }
+
+        if !self.base.with_row_id {
+            results = results.drop_column(ROW_ID)?;
+        }
+
+        Ok(SendableRecordBatchStream::from(
+            RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
+        ))
+    }
 }

 impl ExecutableQuery for VectorQuery {
@@ -873,6 +987,11 @@ impl ExecutableQuery for VectorQuery {
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        if self.base.full_text_search.is_some() {
+            let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
+            return Ok(hybrid_result);
+        }
+
         Ok(SendableRecordBatchStream::from(
             DatasetRecordBatchStream::new(execute_plan(
                 self.create_plan(options).await?,
@@ -894,20 +1013,20 @@ impl HasQuery for VectorQuery {

 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
+    use std::{collections::HashSet, sync::Arc};

     use super::*;
-    use arrow::{compute::concat_batches, datatypes::Int32Type};
+    use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
     use arrow_array::{
-        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
-        RecordBatchReader,
+        cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
+        RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
     };
     use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
     use futures::{StreamExt, TryStreamExt};
     use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
     use tempfile::tempdir;

-    use crate::{connect, Table};
+    use crate::{connect, connection::CreateTableMode, Table};

     #[tokio::test]
     async fn test_setters_getters() {
@@ -1245,6 +1364,30 @@ mod tests {
         }
     }

+    #[tokio::test]
+    async fn test_distance_range() {
+        let tmp_dir = tempdir().unwrap();
+        let table = make_test_table(&tmp_dir).await;
+        let results = table
+            .vector_search(&[0.1, 0.2, 0.3, 0.4])
+            .unwrap()
+            .distance_range(Some(0.0), Some(1.0))
+            .limit(10)
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        for batch in results {
+            let distances = batch["_distance"].as_primitive::<Float32Type>();
+            assert!(distances.iter().all(|d| {
+                let d = d.unwrap();
+                (0.0..1.0).contains(&d)
+            }));
+        }
+    }
+
     #[tokio::test]
     async fn test_multiple_query_vectors() {
         let tmp_dir = tempdir().unwrap();
@@ -1274,4 +1417,156 @@ mod tests {
         assert!(query_index.values().contains(&0));
         assert!(query_index.values().contains(&1));
     }

+    #[tokio::test]
+    async fn test_hybrid_search() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
+        let vectors = vec![
+            Some(vec![Some(0.0), Some(0.0)]),
+            Some(vec![Some(-2.0), Some(-2.0)]),
+            Some(vec![Some(50.0), Some(50.0)]),
+            Some(vec![Some(-30.0), Some(-30.0)]),
+        ];
+        let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);
+
+        let record_batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .execute()
+            .await
+            .unwrap();
+
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let batch = &results[0];
+
+        let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
+        let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
+        assert!(texts.contains("cat")); // should be close by vector search
+        assert!(texts.contains("b")); // should be close by fts search
+
+        // ensure that this works correctly if there are no matching FTS results
+        let fts_query = FullTextSearchQuery::new("z".to_string());
+        table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_hybrid_search_empty_table() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        // ensure hybrid search is also supported on a fully empty table
+        let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
+        let record_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(Vec::<&str>::new())),
+                Arc::new(
+                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
+                ),
+            ],
+        )
+        .unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .mode(CreateTableMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let batch = &results[0];
+        assert_eq!(0, batch.num_rows());
+        assert_eq!(2, batch.num_columns());
+    }
 }
rust/lancedb/src/query/hybrid.rs (new file, +346 lines)
@@ -0,0 +1,346 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow::compute::{
+    kernels::numeric::{div, sub},
+    max, min,
+};
+use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
+use arrow_schema::{DataType, Field, Schema, SortOptions};
+use lance::dataset::ROW_ID;
+use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
+use std::sync::Arc;
+
+use crate::error::{Error, Result};
+
+/// Converts the results' score column to ranks.
|
///
|
||||||
|
/// Expects the `column` argument to be type Float32 and will panic if it's not
|
||||||
|
pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
|
||||||
|
let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
|
||||||
|
message: format!(
|
||||||
|
"expected column {} not found in rank. found columns {:?}",
|
||||||
|
column,
|
||||||
|
results
|
||||||
|
.schema()
|
||||||
|
.fields()
|
||||||
|
.iter()
|
||||||
|
.map(|f| f.name())
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if results.num_rows() == 0 {
|
||||||
|
return Ok(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
let scores: Float32Array = downcast_array(scores);
|
||||||
|
let ranks = Float32Array::from_iter_values(
|
||||||
|
arrow::compute::kernels::rank::rank(
|
||||||
|
&scores,
|
||||||
|
Some(SortOptions {
|
||||||
|
descending: !ascending.unwrap_or(true),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
)?
|
||||||
|
.iter()
|
||||||
|
.map(|i| *i as f32),
|
||||||
|
);
|
||||||
|
|
||||||
|
let schema = results.schema();
|
||||||
|
let (column_idx, _) = schema.column_with_name(column).unwrap();
|
||||||
|
let mut columns = results.columns().to_vec();
|
||||||
|
columns[column_idx] = Arc::new(ranks);
|
||||||
|
|
||||||
|
let results = RecordBatch::try_new(results.schema(), columns)?;
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
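
// For example (hypothetical values): ranking a score column of
// [0.2, 0.4, 0.1] with the default ascending order yields [2.0, 3.0, 1.0];
// the smallest score receives rank 1 and the column is replaced in place.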

/// Get the query schemas needed when combining the search results.
///
/// If either of the record batches is empty, then we create a schema from the
/// other record batch and replace the score/distance column. If both record
/// batches are empty, create empty schemas.
pub fn query_schemas(
    fts_results: &[RecordBatch],
    vec_results: &[RecordBatch],
) -> (Arc<Schema>, Arc<Schema>) {
    let (fts_schema, vec_schema) = match (
        fts_results.first().map(|r| r.schema()),
        vec_results.first().map(|r| r.schema()),
    ) {
        (Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
        (None, Some(vec_schema)) => {
            let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
            (Arc::new(fts_schema), vec_schema)
        }
        (Some(fts_schema), None) => {
            let vec_schema = with_field_name_replaced(&fts_schema, SCORE_COL, DIST_COL);
            (fts_schema, Arc::new(vec_schema))
        }
        (None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
    };

    (fts_schema, vec_schema)
}
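
// For example, if only the vector search returned rows, the FTS schema is
// derived from the vector schema by renaming its DIST_COL field to SCORE_COL,
// keeping the two sides schema-compatible for later concatenation.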

pub fn empty_fts_schema() -> Schema {
    Schema::new(vec![
        Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
    ])
}

pub fn empty_vec_schema() -> Schema {
    Schema::new(vec![
        Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
    ])
}

pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
    let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
        if field.name() == target {
            Some(i)
        } else {
            None
        }
    });

    let mut fields = schema.fields().to_vec();
    if let Some(idx) = field_idx {
        let new_field = (*fields[idx]).clone().with_name(replacement);
        fields[idx] = Arc::new(new_field);
    }

    Schema::new(fields)
}

/// Normalize the scores column to have values between 0 and 1.
///
/// Expects the `column` argument to be of type Float32 and will panic if it's not.
pub fn normalize_scores(
    results: RecordBatch,
    column: &str,
    invert: Option<bool>,
) -> Result<RecordBatch> {
    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
        message: format!(
            "expected column {} not found in normalize_scores. found columns {:?}",
            column,
            results
                .schema()
                .fields()
                .iter()
                .map(|f| f.name())
                .collect::<Vec<_>>(),
        ),
    })?;

    if results.num_rows() == 0 {
        return Ok(results);
    }
    let mut scores: Float32Array = downcast_array(scores);

    let max = max(&scores).unwrap_or(0.0);
    let min = min(&scores).unwrap_or(0.0);

    // this is equivalent to np.isclose which is used in python
    let rng = if max - min < 10e-5 { max } else { max - min };

    // if rng is 0, then min and max are both 0 so we just leave the scores as is
    if rng != 0.0 {
        let tmp = div(
            &sub(&scores, &Float32Array::new_scalar(min))?,
            &Float32Array::new_scalar(rng),
        )?;
        scores = downcast_array(&tmp);
    }

    if invert.unwrap_or(false) {
        let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
        scores = downcast_array(&tmp);
    }

    let schema = results.schema();
    let (column_idx, _) = schema.column_with_name(column).unwrap();
    let mut columns = results.columns().to_vec();
    columns[column_idx] = Arc::new(scores);

    let results = RecordBatch::try_new(results.schema(), columns).unwrap();

    Ok(results)
}
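
// Worked example (hypothetical values): for scores [-4.0, 6.0, 1.0],
// min = -4 and max = 6, so rng = 10 and each score maps to (x - min) / rng,
// giving [0.0, 1.0, 0.5]; with invert = Some(true) each value x becomes
// 1 - x, giving [1.0, 0.0, 0.5].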

#[cfg(test)]
mod test {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};

    #[test]
    fn test_rank() {
        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new("name", DataType::Utf8, false)),
            Arc::new(Field::new("score", DataType::Float32, false)),
        ]));

        let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]);
        let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]);

        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap();

        let result = rank(batch.clone(), "score", Some(false)).unwrap();
        assert_eq!(2, result.schema().fields().len());
        assert_eq!("name", result.schema().field(0).name());
        assert_eq!("score", result.schema().field(1).name());

        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![4.0, 3.0, 5.0, 1.0, 2.0]
        );

        // check sort ascending
        let result = rank(batch.clone(), "score", Some(true)).unwrap();
        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![2.0, 3.0, 1.0, 5.0, 4.0]
        );

        // ensure default sort is ascending
        let result = rank(batch.clone(), "score", None).unwrap();
        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![2.0, 3.0, 1.0, 5.0, 4.0]
        );

        // check it can handle an empty batch
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(Vec::<&str>::new())),
                Arc::new(Float32Array::from(Vec::<f32>::new())),
            ],
        )
        .unwrap();
        let result = rank(batch.clone(), "score", None).unwrap();
        assert_eq!(0, result.num_rows());
        assert_eq!(2, result.schema().fields().len());
        assert_eq!("name", result.schema().field(0).name());
        assert_eq!("score", result.schema().field(1).name());

        // check it returns the expected error when there's no column
        let result = rank(batch.clone(), "bad_col", None);
        match result {
            Err(Error::InvalidInput { message }) => {
                assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
            }
            _ => {
                panic!("expected invalid input error, received {:?}", result)
            }
        }
    }

    #[test]
    fn test_normalize_scores() {
        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new("name", DataType::Utf8, false)),
            Arc::new(Field::new("score", DataType::Float32, false)),
        ]));

        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
        let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));

        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();

        let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.6, 0.4, 0.7, 1.0]
        );

        // check it can invert the normalization
        let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
        );

        // check that the default is not inverted
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.6, 0.4, 0.7, 1.0]
        );

        // check that it will function correctly if all the values are the same
        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
        let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0, 0.0, 0.0]
        );

        // check it keeps floating point rounding errors for same score normalized the same
        // e.g., the behaviour is consistent with python
        let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                0.0
            ]
        );

        // check that it can handle if all the scores are 0
        let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0, 0.0, 0.0]
        );
    }
}
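Taken together, `normalize_scores` and `rank` are the two normalization strategies a reranker can choose between (cf. `NormalizeMethod` in the new rerankers module below). A minimal sketch of how they behave side by side, assuming it is compiled inside this crate where the `query::hybrid` module and `error::Result` are visible; the values and the function name are hypothetical:

use std::sync::Arc;

use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

use crate::query::hybrid::{normalize_scores, rank};

fn normalization_sketch() -> crate::error::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("score", DataType::Float32, false)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Float32Array::from(vec![2.0, 8.0, 5.0]))],
    )
    .unwrap();

    // Score normalization: (x - min) / (max - min) = (x - 2) / 6.
    let by_score = normalize_scores(batch.clone(), "score", None)?;
    let scores: Float32Array = downcast_array(by_score.column(0));
    assert_eq!(
        scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
        vec![0.0, 1.0, 0.5]
    );

    // Rank normalization, ascending by default: the smallest score gets rank 1.
    let by_rank = rank(batch, "score", None)?;
    let ranks: Float32Array = downcast_array(by_rank.column(0));
    assert_eq!(
        ranks.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
        vec![1.0, 3.0, 2.0]
    );
    Ok(())
}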
@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
         body["prefilter"] = query.base.prefilter.into();
         body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
         body["nprobes"] = query.nprobes.into();
+        body["lower_bound"] = query.lower_bound.into();
+        body["upper_bound"] = query.upper_bound.into();
         body["ef"] = query.ef.into();
         body["refine_factor"] = query.refine_factor.into();
         if let Some(vector_column) = query.column.as_ref() {
@@ -563,6 +565,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         let (index_type, distance_type) = match index.index {
             // TODO: Should we pass the actual index parameters? SaaS does not
             // yet support them.
+            Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)),
             Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
             Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
             Index::BTree(_) => ("BTREE", None),
@@ -873,6 +876,7 @@ mod tests {
     use lance_index::scalar::FullTextSearchQuery;
     use reqwest::Body;
 
+    use crate::index::vector::IvfFlatIndexBuilder;
     use crate::{
         index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
         query::{ExecutableQuery, QueryBase},
@@ -1302,6 +1306,8 @@ mod tests {
             "prefilter": true,
             "distance_type": "l2",
             "nprobes": 20,
+            "lower_bound": Option::<f32>::None,
+            "upper_bound": Option::<f32>::None,
             "k": 10,
             "ef": Option::<usize>::None,
             "refine_factor": null,
@@ -1351,6 +1357,8 @@ mod tests {
             "bypass_vector_index": true,
             "columns": ["a", "b"],
             "nprobes": 12,
+            "lower_bound": Option::<f32>::None,
+            "upper_bound": Option::<f32>::None,
             "ef": Option::<usize>::None,
             "refine_factor": 2,
             "version": null,
@@ -1489,6 +1497,11 @@ mod tests {
     #[tokio::test]
     async fn test_create_index() {
         let cases = [
+            (
+                "IVF_FLAT",
+                Some("hamming"),
+                Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
+            ),
             ("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
             (
                 "IVF_PQ",
87  rust/lancedb/src/rerankers.rs  Normal file
@@ -0,0 +1,87 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::BTreeSet;

use arrow::{
    array::downcast_array,
    compute::{concat_batches, filter_record_batch},
};
use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
use async_trait::async_trait;
use lance::dataset::ROW_ID;

use crate::error::{Error, Result};

pub mod rrf;

/// column name for reranker relevance score
const RELEVANCE_SCORE: &str = "_relevance_score";

#[derive(Debug, Clone, PartialEq)]
pub enum NormalizeMethod {
    Score,
    Rank,
}

/// Interface for a reranker. A reranker is used to rerank the results from a
/// vector and FTS search. This is useful for combining the results from both
/// search methods.
#[async_trait]
pub trait Reranker: std::fmt::Debug + Sync + Send {
    // TODO support vector reranking and FTS reranking. Currently only hybrid reranking is supported.

    /// The rerank function receives the individual results from the vector and
    /// FTS searches. You can choose to use any of the results to generate the
    /// final results, allowing maximum flexibility.
    async fn rerank_hybrid(
        &self,
        query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch>;

    fn merge_results(
        &self,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;

        let mut mask = BooleanArray::builder(combined.num_rows());
        let mut unique_ids = BTreeSet::new();
        let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
            message: format!(
                "could not find expected column {} while merging results. found columns {:?}",
                ROW_ID,
                combined
                    .schema()
                    .fields()
                    .iter()
                    .map(|f| f.name())
                    .collect::<Vec<_>>()
            ),
        })?;
        let row_ids: UInt64Array = downcast_array(row_ids);
        row_ids.values().iter().for_each(|id| {
            mask.append_value(unique_ids.insert(id));
        });

        let combined = filter_record_batch(&combined, &mask.finish())?;

        Ok(combined)
    }
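
    // Note: `vector_results` is concatenated ahead of `fts_results`, so the
    // first-occurrence mask above keeps the vector row whenever both searches
    // returned the same _rowid.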
}

pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
    if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
        return Err(Error::Schema {
            message: format!(
                "rerank_hybrid must return a RecordBatch with a column named {}",
                RELEVANCE_SCORE
            ),
        });
    }

    Ok(())
}
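Because `rerank_hybrid` is the only required method, a custom reranker just has to return a batch carrying a `_relevance_score` column (`check_reranker_result` enforces this). A minimal sketch of an implementor, assuming it lives inside the `rerankers` module tree so the private `RELEVANCE_SCORE` constant is visible; the type and its flat scoring rule are hypothetical:

use std::sync::Arc;

use arrow_array::{Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;

use crate::error::Result;
use crate::rerankers::{Reranker, RELEVANCE_SCORE};

/// Hypothetical reranker that merges both result sets and assigns a flat score.
#[derive(Debug)]
struct ConstantReranker;

#[async_trait]
impl Reranker for ConstantReranker {
    async fn rerank_hybrid(
        &self,
        _query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        // Dedup on _rowid using the trait's default merge.
        let merged = self.merge_results(vector_results, fts_results)?;

        // Append the required _relevance_score column.
        let scores = Float32Array::from(vec![1.0_f32; merged.num_rows()]);
        let mut fields = merged.schema().fields().to_vec();
        fields.push(Arc::new(Field::new(RELEVANCE_SCORE, DataType::Float32, false)));
        let mut columns = merged.columns().to_vec();
        columns.push(Arc::new(scores));
        Ok(RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)?)
    }
}

The default `merge_results` already handles `_rowid` deduplication, so an implementation only needs to decide how the merged rows are scored and ordered.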
223  rust/lancedb/src/rerankers/rrf.rs  Normal file
@@ -0,0 +1,223 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::BTreeMap;
use std::sync::Arc;

use arrow::{
    array::downcast_array,
    compute::{sort_to_indices, take},
};
use arrow_array::{Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use async_trait::async_trait;
use lance::dataset::ROW_ID;

use crate::error::{Error, Result};
use crate::rerankers::{Reranker, RELEVANCE_SCORE};

/// Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm,
/// based on the ranks of the vector and FTS search results.
#[derive(Debug)]
pub struct RRFReranker {
    k: f32,
}

impl RRFReranker {
    /// Create a new RRFReranker
    ///
    /// The parameter k is a constant used in the RRF formula (default is 60).
    /// Experiments indicate that k = 60 was near-optimal, but that the choice
    /// is not critical. See paper:
    /// https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
    pub fn new(k: f32) -> Self {
        Self { k }
    }
}

impl Default for RRFReranker {
    fn default() -> Self {
        Self { k: 60.0 }
    }
}
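
// The fused score of a row is the sum, over the result lists it appears in,
// of 1 / (position + k) with 0-based positions. For instance, with k = 1 a
// row placed first in the FTS results and second in the vector results
// scores 1/(0 + 1) + 1/(1 + 1) = 1.5 (the `bar` case in the test below).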

#[async_trait]
impl Reranker for RRFReranker {
    async fn rerank_hybrid(
        &self,
        _query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        let vector_ids = vector_results
            .column_by_name(ROW_ID)
            .ok_or(Error::InvalidInput {
                message: format!(
                    "expected column {} not found in vector_results. found columns {:?}",
                    ROW_ID,
                    vector_results
                        .schema()
                        .fields()
                        .iter()
                        .map(|f| f.name())
                        .collect::<Vec<_>>()
                ),
            })?;
        let fts_ids = fts_results
            .column_by_name(ROW_ID)
            .ok_or(Error::InvalidInput {
                message: format!(
                    "expected column {} not found in fts_results. found columns {:?}",
                    ROW_ID,
                    fts_results
                        .schema()
                        .fields()
                        .iter()
                        .map(|f| f.name())
                        .collect::<Vec<_>>()
                ),
            })?;

        let vector_ids: UInt64Array = downcast_array(&vector_ids);
        let fts_ids: UInt64Array = downcast_array(&fts_ids);

        let mut rrf_score_map = BTreeMap::new();
        let mut update_score_map = |(i, result_id)| {
            let score = 1.0 / (i as f32 + self.k);
            rrf_score_map
                .entry(result_id)
                .and_modify(|e| *e += score)
                .or_insert(score);
        };
        vector_ids
            .values()
            .iter()
            .enumerate()
            .for_each(&mut update_score_map);
        fts_ids
            .values()
            .iter()
            .enumerate()
            .for_each(&mut update_score_map);

        let combined_results = self.merge_results(vector_results, fts_results)?;

        let combined_row_ids: UInt64Array =
            downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
        let relevance_scores = Float32Array::from_iter_values(
            combined_row_ids
                .values()
                .iter()
                .map(|row_id| rrf_score_map.get(row_id).unwrap())
                .copied(),
        );

        // keep track of indices sorted by the relevance column
        let sort_indices = sort_to_indices(
            &relevance_scores,
            Some(SortOptions {
                descending: true,
                ..Default::default()
            }),
            None,
        )
        .unwrap();

        // add relevance scores to columns
        let mut columns = combined_results.columns().to_vec();
        columns.push(Arc::new(relevance_scores));

        // sort by the relevance scores
        let columns = columns
            .iter()
            .map(|c| take(c, &sort_indices, None).unwrap())
            .collect();

        // add relevance score to schema
        let mut fields = combined_results.schema().fields().to_vec();
        fields.push(Arc::new(Field::new(
            RELEVANCE_SCORE,
            DataType::Float32,
            false,
        )));
        let schema = Schema::new(fields);

        let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;

        Ok(combined_results)
    }
}

#[cfg(test)]
pub mod test {
    use super::*;
    use arrow_array::StringArray;

    #[tokio::test]
    async fn test_rrf_reranker() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new(ROW_ID, DataType::UInt64, false),
        ]));

        let vec_results = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"])),
                Arc::new(UInt64Array::from(vec![1, 4, 2, 5, 3])),
            ],
        )
        .unwrap();

        let fts_results = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["bar", "bean", "dog"])),
                Arc::new(UInt64Array::from(vec![4, 5, 3])),
            ],
        )
        .unwrap();

        // scores should be calculated as:
        // - foo = 1/1 = 1.0
        // - bar = 1/2 + 1/1 = 1.5
        // - baz = 1/3 = 0.333
        // - bean = 1/4 + 1/2 = 0.75
        // - dog = 1/5 + 1/3 = 0.533
        // then we should get the result ranked in descending order

        let reranker = RRFReranker::new(1.0);

        let result = reranker
            .rerank_hybrid("", vec_results, fts_results)
            .await
            .unwrap();

        assert_eq!(3, result.schema().fields().len());
        assert_eq!("name", result.schema().fields().first().unwrap().name());
        assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name());
        assert_eq!(
            RELEVANCE_SCORE,
            result.schema().fields().get(2).unwrap().name()
        );

        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["bar", "foo", "bean", "dog", "baz"]
        );

        let ids: UInt64Array = downcast_array(result.column(1));
        assert_eq!(
            ids.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![4, 1, 5, 3, 2]
        );

        let scores: Float32Array = downcast_array(result.column(2));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0]
        );
    }
}
@@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable {
         if let Some(ef) = query.ef {
             scanner.ef(ef);
         }
+        scanner.distance_range(query.lower_bound, query.upper_bound);
         scanner.use_index(query.use_index);
         scanner.prefilter(query.base.prefilter);
         match query.base.select {
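The remaining hunks plumb the new `lower_bound` / `upper_bound` query fields from the remote request body down to the native scanner via `scanner.distance_range(...)`. As a rough sketch of what such a range expresses (a hypothetical standalone helper, not code from this diff; whether the engine treats the bounds as inclusive or exclusive is not shown here, so a half-open interval is assumed):

/// Hypothetical illustration only: keep a hit when its distance satisfies
/// the optional bounds (modeled here as a half-open interval).
fn within_distance_range(distance: f32, lower: Option<f32>, upper: Option<f32>) -> bool {
    lower.map_or(true, |lb| distance >= lb) && upper.map_or(true, |ub| distance < ub)
}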