mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00

Compare commits (16 commits): python-v0.… ... python-v0.…
Commits in this compare:

- 92d845fa72
- 397813f6a4
- 50c30c5d34
- c9f248b058
- 0cb6da6b7e
- aec8332eb5
- 46061070e6
- dae8334d0b
- 8c81968b59
- 16cf2990f3
- 0a0f667bbd
- 03753fd84b
- 55cceaa309
- c3797eb834
- c0d0f38494
- 6a8ab78d0a
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.14.1-beta.6"
+current_version = "0.14.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.

.github/workflows/npm-publish.yml (vendored, 196 lines changed)
@@ -159,7 +159,7 @@ jobs:
      - name: Install common dependencies
        run: |
          apk add protobuf-dev curl clang mold grep npm bash
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
          echo "source $HOME/.cargo/env" >> saved_env
          echo "export CC=clang" >> saved_env
          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -167,7 +167,7 @@ jobs:
        if: ${{ matrix.config.arch == 'aarch64' }}
        run: |
          source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
          crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
          sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
          apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -262,7 +262,7 @@ jobs:
      - name: Install common dependencies
        run: |
          apk add protobuf-dev curl clang mold grep npm bash openssl-dev openssl-libs-static
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
          echo "source $HOME/.cargo/env" >> saved_env
          echo "export CC=clang" >> saved_env
          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -272,7 +272,7 @@ jobs:
        if: ${{ matrix.config.arch == 'aarch64' }}
        run: |
          source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
          crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
          sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
          apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -334,50 +334,51 @@ jobs:
          path: |
            node/dist/lancedb-vectordb-win32*.tgz

-  node-windows-arm64:
-    name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: node-native-windows-${{ matrix.config.arch }}
-          path: |
-            node/dist/lancedb-vectordb-win32*.tgz
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # node-windows-arm64:
+  #   name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: node-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           node/dist/lancedb-vectordb-win32*.tgz

  nodejs-windows:
    name: lancedb ${{ matrix.target }}
@@ -413,57 +414,58 @@ jobs:
          path: |
            nodejs/dist/*.node

-  nodejs-windows-arm64:
-    name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
-    # Only runs on tags that matches the make-release action
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-          printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
-          chmod u+x $HOME/.cargo/bin/cargo-xwin
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: nodejs-native-windows-${{ matrix.config.arch }}
-          path: |
-            nodejs/dist/*.node
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # nodejs-windows-arm64:
+  #   name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # Only runs on tags that matches the make-release action
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #         printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
+  #         chmod u+x $HOME/.cargo/bin/cargo-xwin
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: nodejs-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           nodejs/dist/*.node

  release:
    name: vectordb NPM Publish
-    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows, node-windows-arm64]
+    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows]
    runs-on: ubuntu-latest
    # Only runs on tags that matches the make-release action
    if: startsWith(github.ref, 'refs/tags/v')
@@ -503,7 +505,7 @@ jobs:

  release-nodejs:
    name: lancedb NPM Publish
-    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows, nodejs-windows-arm64]
+    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows]
    runs-on: ubuntu-latest
    # Only runs on tags that matches the make-release action
    if: startsWith(github.ref, 'refs/tags/v')

Cargo.toml (18 lines changed)
@@ -21,16 +21,16 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"

 [workspace.dependencies]
-lance = { "version" = "=0.21.0", "features" = [
+lance = { "version" = "=0.21.1", "features" = [
     "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-io = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-index = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-linalg = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-table = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-testing = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-datafusion = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
-lance-encoding = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" }
+], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
 # Note that this one does not include pyarrow
 arrow = { version = "53.2", optional = false }
 arrow-array = "53.2"
@@ -133,6 +133,10 @@ lists the indices that LanceDb supports.

 ::: lancedb.index.IvfPq

+::: lancedb.index.HnswPq
+
+::: lancedb.index.HnswSq
+
 ::: lancedb.index.IvfFlat

 ## Querying (Asynchronous)
@@ -8,7 +8,7 @@
  <parent>
    <groupId>com.lancedb</groupId>
    <artifactId>lancedb-parent</artifactId>
-    <version>0.14.1-beta.6</version>
+    <version>0.14.1-final.0</version>
    <relativePath>../pom.xml</relativePath>
  </parent>

@@ -6,7 +6,7 @@

  <groupId>com.lancedb</groupId>
  <artifactId>lancedb-parent</artifactId>
-  <version>0.14.1-beta.6</version>
+  <version>0.14.1-final.0</version>
  <packaging>pom</packaging>

  <name>LanceDB Parent</name>

node/package-lock.json (generated, 111 lines changed)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.14.1-beta.6",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"
@@ -52,14 +52,14 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.6",
-        "@lancedb/vectordb-darwin-x64": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.6",
-        "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.6",
-        "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.6"
+        "@lancedb/vectordb-darwin-arm64": "0.14.1",
+        "@lancedb/vectordb-darwin-x64": "0.14.1",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+        "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+        "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+        "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+        "@lancedb/vectordb-win32-arm64-msvc": "0.14.1",
+        "@lancedb/vectordb-win32-x64-msvc": "0.14.1"
       },
       "peerDependencies": {
         "@apache-arrow/ts": "^14.0.2",
@@ -329,6 +329,97 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
+    "node_modules/@lancedb/vectordb-darwin-arm64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.14.1.tgz",
+      "integrity": "sha512-6t7XHR7dBjDmAS/kz5wbe7LPhKW+WkFA16ZPyh0lmuxfnss4VvN3LE6qQBHjzYzB9U6Nu/4ktQ50xZGEPTnc5A==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-darwin-x64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.14.1.tgz",
+      "integrity": "sha512-8q6Kd6XnNPKN8wqj75pHVQ4KFl6z9BaI6lWDiEaCNcO3bjPZkcLFNosJq4raxZ9iUi50Yl0qFJ6qR0XFVTwnnw==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.14.1.tgz",
+      "integrity": "sha512-4djEMmeNb+p6nW/C4xb8wdMwnIbWfO8fYAwiplOxzxeOpPaUC9rhwUUDCbrJDCpMa8RP5ED4/jC6yT8epaDMDw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.14.1.tgz",
+      "integrity": "sha512-c33hSsp16pnC58plzx1OXuifp9Rachx/MshE/L/OReoutt74fFdrRJwUjE4UCAysyY5QdvTrNm9OhDjopQK2Bw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.14.1.tgz",
+      "integrity": "sha512-psu6cH9iLiSbUEZD1EWbOA4THGYSwJvS2XICO9yN7A6D41AP/ynYMRZNKWo1fpdi2Fjb0xNQwiNhQyqwbi5gzA==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.14.1.tgz",
+      "integrity": "sha512-Rg4VWW80HaTFmR7EvNSu+nfRQQM8beO/otBn/Nus5mj5zFw/7cacGRmiEYhDnk5iAn8nauV+Jsi9j2U+C2hp5w==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.14.1.tgz",
+      "integrity": "sha512-XbifasmMbQIt3V9P0AtQND6M3XFiIAc1ZIgmjzBjOmxwqw4sQUwHMyJGIGOzKFZTK3fPJIGRHId7jAzXuBgfQg==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "win32"
+      ]
+    },
     "node_modules/@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "description": " Serverless, low-latency vector database for AI applications",
   "private": false,
   "main": "dist/index.js",
@@ -92,13 +92,13 @@
    }
  },
  "optionalDependencies": {
-    "@lancedb/vectordb-darwin-x64": "0.14.1-beta.6",
-    "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.6",
-    "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.6",
-    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.6"
+    "@lancedb/vectordb-darwin-x64": "0.14.1",
+    "@lancedb/vectordb-darwin-arm64": "0.14.1",
+    "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+    "@lancedb/vectordb-win32-x64-msvc": "0.14.1",
+    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1"
  }
 }
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.14.1-beta.6"
+version = "0.14.1"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
@@ -12,7 +12,10 @@ categories.workspace = true
 crate-type = ["cdylib"]

 [dependencies]
+async-trait.workspace = true
 arrow-ipc.workspace = true
+arrow-array.workspace = true
+arrow-schema.workspace = true
 env_logger.workspace = true
 futures.workspace = true
 lancedb = { path = "../rust/lancedb", features = ["remote"] }
@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";

 import {
   convertToTable,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   fromTableToBuffer,
   makeArrowTable,
   makeEmptyTable,
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
      });
    });
  });
+
+  describe("converting record batches to buffers", function () {
+    it("can convert to buffered record batch and back again", async function () {
+      const records = [
+        { text: "dog", vector: [0.1, 0.2] },
+        { text: "cat", vector: [0.3, 0.4] },
+      ];
+      const table = await convertToTable(records);
+      const batch = table.batches[0];
+
+      const buffer = await fromRecordBatchToBuffer(batch);
+      const result = await fromBufferToRecordBatch(buffer);
+
+      expect(JSON.stringify(batch.toArray())).toEqual(
+        JSON.stringify(result?.toArray()),
+      );
+    });
+
+    it("converting from buffer returns null if buffer has no record batches", async function () {
+      const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
+      expect(result).toEqual(null);
+    });
+  });
  },
 );
nodejs/__test__/rerankers.test.ts (new file, 79 lines)

@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import * as tmp from "tmp";
+import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
+import { RRFReranker } from "../lancedb/rerankers";
+
+describe("rerankers", function () {
+  let tmpDir: tmp.DirResult;
+  let conn: Connection;
+  let table: Table;
+
+  beforeEach(async () => {
+    tmpDir = tmp.dirSync({ unsafeCleanup: true });
+    conn = await connect(tmpDir.name);
+    table = await conn.createTable("mytable", [
+      { vector: [0.1, 0.1], text: "dog" },
+      { vector: [0.2, 0.2], text: "cat" },
+    ]);
+    await table.createIndex("text", {
+      config: Index.fts(),
+      replace: true,
+    });
+  });
+
+  it("will query with the custom reranker", async function () {
+    const expectedResult = [
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ];
+    class MyCustomReranker {
+      async rerankHybrid(
+        _query: string,
+        _vecResults: RecordBatch,
+        _ftsResults: RecordBatch,
+      ): Promise<RecordBatch> {
+        // no reranker logic, just return some static data
+        const table = makeArrowTable(expectedResult);
+        return table.batches[0];
+      }
+    }
+
+    let result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(new MyCustomReranker())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
+    expect(result).toEqual([
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ]);
+  });
+
+  it("will query with RRFReranker", async function () {
+    // smoke test to see if the Rust wrapping Typescript is wired up correctly
+    const result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(await RRFReranker.create())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    expect(result).toHaveLength(2);
+  });
+});
@@ -27,7 +27,9 @@ import {
   List,
   Null,
   RecordBatch,
+  RecordBatchFileReader,
   RecordBatchFileWriter,
+  RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
  }
 }

+/**
+ * Read a single record batch from a buffer.
+ *
+ * Returns null if the buffer does not contain a record batch
+ */
+export async function fromBufferToRecordBatch(
+  data: Buffer,
+): Promise<RecordBatch | null> {
+  const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
+    .value;
+  const recordBatch = iter?.next().value;
+  return recordBatch || null;
+}
+
+/**
+ * Create a buffer containing a single record batch
+ */
+export async function fromRecordBatchToBuffer(
+  batch: RecordBatch,
+): Promise<Buffer> {
+  const writer = new RecordBatchFileWriter().writeAll([batch]);
+  return Buffer.from(await writer.toUint8Array());
+}
+
 /**
  * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
  *
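The two helpers above give the reranker bridge a single-batch Arrow IPC round trip. A minimal sketch of how they pair up, mirroring the new test earlier in this diff (the import path is an assumption based on this package's layout):

```ts
import {
  convertToTable,
  fromBufferToRecordBatch,
  fromRecordBatchToBuffer,
} from "./arrow"; // assumed path

async function roundTrip(): Promise<void> {
  // Build a one-batch Arrow table from plain JS records.
  const table = await convertToTable([{ text: "dog", vector: [0.1, 0.2] }]);
  const batch = table.batches[0];

  // RecordBatch -> Arrow IPC file bytes -> RecordBatch.
  const buffer = await fromRecordBatchToBuffer(batch);
  const restored = await fromBufferToRecordBatch(buffer);

  // fromBufferToRecordBatch yields null when the buffer holds no batch.
  console.log(restored?.numRows); // 1
}
```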
@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
 export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";

 export * as embedding from "./embedding";
+export * as rerankers from "./rerankers";

 /**
  * Connect to a LanceDB instance at the given URI.
@@ -16,6 +16,8 @@ import {
   Table as ArrowTable,
   type IntoVector,
   RecordBatch,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   tableFromIPC,
 } from "./arrow";
 import { type IvfPqOptions } from "./indices";
@@ -25,6 +27,7 @@ import {
   Table as NativeTable,
   VectorQuery as NativeVectorQuery,
 } from "./native";
+import { Reranker } from "./rerankers";
 export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
   private promisedInner?: Promise<NativeBatchIterator>;
   private inner?: NativeBatchIterator;
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
      return this;
    }
  }
+
+  rerank(reranker: Reranker): VectorQuery {
+    super.doCall((inner) =>
+      inner.rerank({
+        rerankHybrid: async (_, args) => {
+          const vecResults = await fromBufferToRecordBatch(args.vecResults);
+          const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
+          const result = await reranker.rerankHybrid(
+            args.query,
+            vecResults as RecordBatch,
+            ftsResults as RecordBatch,
+          );
+
+          const buffer = fromRecordBatchToBuffer(result);
+          return buffer;
+        },
+      }),
+    );
+
+    return this;
+  }
 }

 /** A builder for LanceDB queries. */
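Taken together, any object exposing a matching `rerankHybrid` method can now be dropped into a hybrid query. A usage sketch based on the new tests (database path and table name are placeholders, and the table is assumed to already have an FTS index on "text"):

```ts
import * as lancedb from "@lancedb/lancedb";

async function hybridSearch(): Promise<void> {
  const db = await lancedb.connect("/tmp/lancedb-demo"); // placeholder URI
  const table = await db.openTable("mytable"); // assumes an FTS index on "text"

  const results = await table
    .query()
    .nearestTo([0.1, 0.1]) // vector side of the hybrid query
    .fullTextSearch("dog") // FTS side of the hybrid query
    .rerank(await lancedb.rerankers.RRFReranker.create()) // fuse both rankings
    .select(["text"])
    .limit(5)
    .toArray();

  console.log(results);
}
```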
nodejs/lancedb/rerankers/index.ts (new file, 17 lines)

@@ -0,0 +1,17 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+
+export * from "./rrf";
+
+// Interface for a reranker. A reranker is used to rerank the results from a
+// vector and FTS search. This is useful for combining the results from both
+// search methods.
+export interface Reranker {
+  rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch>;
+}
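The interface is structural, so any class (or plain object) with this one async method qualifies; the test file earlier in this diff does exactly that. A deliberately trivial sketch (illustrative only, not part of the commit) that keeps the vector-side ordering and ignores the FTS ranking:

```ts
import { RecordBatch } from "apache-arrow";
import type { Reranker } from "./rerankers"; // assumed import path

// Trivial fusion policy: return the vector results untouched.
class VectorOnlyReranker implements Reranker {
  async rerankHybrid(
    _query: string,
    vecResults: RecordBatch,
    _ftsResults: RecordBatch,
  ): Promise<RecordBatch> {
    return vecResults;
  }
}
```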
nodejs/lancedb/rerankers/rrf.ts (new file, 40 lines)

@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
+import { RrfReranker as NativeRRFReranker } from "../native";
+
+/**
+ * Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
+ *
+ * Internally this uses the Rust implementation
+ */
+export class RRFReranker {
+  private inner: NativeRRFReranker;
+
+  constructor(inner: NativeRRFReranker) {
+    this.inner = inner;
+  }
+
+  public static async create(k: number = 60) {
+    return new RRFReranker(
+      await NativeRRFReranker.tryNew(new Float32Array([k])),
+    );
+  }
+
+  async rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch> {
+    const buffer = await this.inner.rerankHybrid(
+      query,
+      await fromRecordBatchToBuffer(vecResults),
+      await fromRecordBatchToBuffer(ftsResults),
+    );
+    const recordBatch = await fromBufferToRecordBatch(buffer);
+
+    return recordBatch as RecordBatch;
+  }
+}
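For reference, Reciprocal Rank Fusion scores a row by summing reciprocal ranks across the input result lists, with the `k` passed to `create()` (60 by default, as above) damping the head of each list. This is the standard RRF definition; the Rust body itself is not shown in this diff:

\[
\mathrm{score}(d) = \sum_{r \,\in\, \{\mathrm{vec},\,\mathrm{fts}\}} \frac{1}{k + \mathrm{rank}_r(d)}
\]

Rows ranked highly in both lists receive the largest fused scores, and a larger k flattens the differences between ranks.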
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": [
     "win32"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
nodejs/package-lock.json (generated, 4 lines changed)

@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.14.1-beta.6",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"
@@ -11,7 +11,7 @@
    "ann"
  ],
  "private": false,
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
  "main": "dist/index.js",
  "exports": {
    ".": "./dist/index.js",
@@ -24,6 +24,7 @@ mod iterator;
 pub mod merge;
 mod query;
 pub mod remote;
+mod rerankers;
 mod table;
 mod util;

@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::sync::Arc;
+
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::ExecutableQuery;
 use lancedb::query::Query as LanceDbQuery;
@@ -25,6 +27,8 @@ use napi_derive::napi;
 use crate::error::convert_error;
 use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
+use crate::rerankers::Reranker;
+use crate::rerankers::RerankerCallbacks;
 use crate::util::parse_distance_type;

 #[napi]
@@ -218,6 +222,14 @@ impl VectorQuery {
        self.inner = self.inner.clone().with_row_id();
    }

+    #[napi]
+    pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
+        self.inner = self
+            .inner
+            .clone()
+            .rerank(Arc::new(Reranker::new(callbacks)));
+    }
+
    #[napi(catch_unwind)]
    pub async fn execute(
        &self,
nodejs/src/rerankers.rs
Normal file
147
nodejs/src/rerankers.rs
Normal file
@@ -0,0 +1,147 @@
|
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow_array::RecordBatch;
+use async_trait::async_trait;
+use napi::{
+    bindgen_prelude::*,
+    threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
+};
+use napi_derive::napi;
+
+use lancedb::ipc::batches_to_ipc_file;
+use lancedb::rerankers::Reranker as LanceDBReranker;
+use lancedb::{error::Error, ipc::ipc_file_to_batches};
+
+use crate::error::NapiErrorExt;
+
+/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
+/// This contains references to the callbacks that can be used to invoke the
+/// reranking methods on the NodeJS implementation and handles serializing the
+/// record batches to Arrow IPC buffers.
+#[napi]
+pub struct Reranker {
+    /// callback to the Javascript which will call the rerankHybrid method of
+    /// some Reranker implementation
+    rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
+}
+
+#[napi]
+impl Reranker {
+    #[napi]
+    pub fn new(callbacks: RerankerCallbacks) -> Self {
+        let rerank_hybrid = callbacks
+            .rerank_hybrid
+            .create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
+            .unwrap();
+
+        Self { rerank_hybrid }
+    }
+}
+
+#[async_trait]
+impl lancedb::rerankers::Reranker for Reranker {
+    async fn rerank_hybrid(
+        &self,
+        query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> lancedb::error::Result<RecordBatch> {
+        let callback_args = RerankHybridCallbackArgs {
+            query: query.to_string(),
+            vec_results: batches_to_ipc_file(&[vector_results])?,
+            fts_results: batches_to_ipc_file(&[fts_results])?,
+        };
+        let promised_buffer: Promise<Buffer> = self
+            .rerank_hybrid
+            .call_async(Ok(callback_args))
+            .await
+            .map_err(|e| Error::Runtime {
+                message: format!("napi error status={}, reason={}", e.status, e.reason),
+            })?;
+        let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
+            message: format!("napi error status={}, reason={}", e.status, e.reason),
+        })?;
+        let mut reader = ipc_file_to_batches(buffer.to_vec())?;
+        let result = reader.next().ok_or(Error::Runtime {
+            message: "reranker result deserialization failed".to_string(),
+        })??;
+
+        return Ok(result);
+    }
+}
+
+impl std::fmt::Debug for Reranker {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("NodeJSRerankerWrapper")
+    }
+}
+
+#[napi(object)]
+pub struct RerankerCallbacks {
+    pub rerank_hybrid: JsFunction,
+}
+
+#[napi(object)]
+pub struct RerankHybridCallbackArgs {
+    pub query: String,
+    pub vec_results: Vec<u8>,
+    pub fts_results: Vec<u8>,
+}
+
+fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
+    let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
+    reader
+        .next()
+        .ok_or(Error::InvalidInput {
+            message: "expected buffer containing record batch".to_string(),
+        })
+        .default_error()?
+        .map_err(Error::from)
+        .default_error()
+}
+
+/// Wrapper around rust RRFReranker
+#[napi]
+pub struct RRFReranker {
+    inner: lancedb::rerankers::rrf::RRFReranker,
+}
+
+#[napi]
+impl RRFReranker {
+    #[napi]
+    pub async fn try_new(k: &[f32]) -> Result<Self> {
+        let k = k
+            .first()
+            .copied()
+            .ok_or(Error::InvalidInput {
+                message: "must supply RRF Reranker constructor arg 'k'".to_string(),
+            })
+            .default_error()?;
+
+        Ok(Self {
+            inner: lancedb::rerankers::rrf::RRFReranker::new(k),
+        })
+    }
+
+    #[napi]
+    pub async fn rerank_hybrid(
+        &self,
+        query: String,
+        vec_results: Buffer,
+        fts_results: Buffer,
+    ) -> Result<Buffer> {
+        let vec_results = buffer_to_record_batch(vec_results)?;
+        let fts_results = buffer_to_record_batch(fts_results)?;
+
+        let result = self
+            .inner
+            .rerank_hybrid(&query, vec_results, fts_results)
+            .await
+            .unwrap();
+
+        let result_buff = batches_to_ipc_file(&[result]).default_error()?;
+
+        Ok(Buffer::from(result_buff.as_ref()))
+    }
+}
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.17.1"
+current_version = "0.17.2-beta.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.17.1"
+version = "0.17.2-beta.1"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
@@ -1,9 +1,10 @@
 [project]
 name = "lancedb"
 # version in Cargo.toml
+dynamic = ["version"]
 dependencies = [
     "deprecation",
-    "pylance==0.21.0b5",
+    "pylance==0.21.1b1",
     "tqdm>=4.27.0",
     "pydantic>=1.10",
     "packaging",
@@ -37,7 +37,7 @@ class Table:
    async def count_rows(self, filter: Optional[str]) -> int: ...
    async def create_index(self, column: str, config, replace: Optional[bool]): ...
    async def version(self) -> int: ...
-    async def checkout(self, version): ...
+    async def checkout(self, version: int): ...
    async def checkout_latest(self): ...
    async def restore(self): ...
    async def list_indices(self) -> List[IndexConfig]: ...
@@ -568,4 +568,14 @@ class IvfPq:
    sample_rate: int = 256


-__all__ = ["BTree", "IvfFlat", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]
+__all__ = [
+    "BTree",
+    "IvfPq",
+    "IvfFlat",
+    "HnswPq",
+    "HnswSq",
+    "IndexConfig",
+    "FTS",
+    "Bitmap",
+    "LabelList",
+]
@@ -1,15 +1,5 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors

 from datetime import timedelta
 import logging
@@ -19,7 +9,7 @@ import warnings

 from lancedb._lancedb import IndexConfig
 from lancedb.embeddings.base import EmbeddingFunctionConfig
-from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
+from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
 from lancedb.remote.db import LOOP
 import pyarrow as pa

@@ -91,7 +81,7 @@ class RemoteTable(Table):
        """to_pandas() is not yet supported on LanceDB cloud."""
        return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")

-    def checkout(self, version):
+    def checkout(self, version: int):
        return LOOP.run(self._table.checkout(version))

    def checkout_latest(self):
@@ -235,10 +225,12 @@ class RemoteTable(Table):
            config = HnswPq(distance_type=metric)
        elif index_type == "IVF_HNSW_SQ":
            config = HnswSq(distance_type=metric)
+        elif index_type == "IVF_FLAT":
+            config = IvfFlat(distance_type=metric)
        else:
            raise ValueError(
                f"Unknown vector index type: {index_type}. Valid options are"
-                " 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
+                " 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
            )

        LOOP.run(self._table.create_index(vector_column_name, config=config))
@@ -1092,7 +1092,7 @@ class Table(ABC):
        """

    @abstractmethod
-    def checkout(self):
+    def checkout(self, version: int):
        """
        Checks out a specific version of the Table

@@ -3049,7 +3049,7 @@ class AsyncTable:

        return versions

-    async def checkout(self, version):
+    async def checkout(self, version: int):
        """
        Checks out a specific version of the Table

@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-node"
-version = "0.14.1-beta.6"
+version = "0.14.1"
 description = "Serverless, low-latency vector database for AI applications"
 license.workspace = true
 edition.workspace = true
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.14.1-beta.6"
+version = "0.14.1"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
 pub mod query;
 #[cfg(feature = "remote")]
 pub mod remote;
+pub mod rerankers;
 pub mod table;
 pub mod utils;
 
@@ -15,19 +15,31 @@
 use std::future::Future;
 use std::sync::Arc;
 
+use arrow::compute::concat_batches;
 use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
 use arrow_schema::DataType;
 use datafusion_physical_plan::ExecutionPlan;
+use futures::{stream, try_join, FutureExt, TryStreamExt};
 use half::f16;
-use lance::dataset::scanner::DatasetRecordBatchStream;
+use lance::{
+    arrow::RecordBatchExt,
+    dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
+};
 use lance_datafusion::exec::execute_plan;
+use lance_index::scalar::inverted::SCORE_COL;
 use lance_index::scalar::FullTextSearchQuery;
+use lance_index::vector::DIST_COL;
+use lance_io::stream::RecordBatchStreamAdapter;
 
 use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
+use crate::rerankers::rrf::RRFReranker;
+use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
 use crate::table::TableInternal;
 use crate::DistanceType;
 
+mod hybrid;
+
 pub(crate) const DEFAULT_TOP_K: usize = 10;
 
 /// Which columns should be retrieved from the database
@@ -435,6 +447,16 @@ pub trait QueryBase {
 
     /// Return the `_rowid` meta column from the Table.
     fn with_row_id(self) -> Self;
 
+    /// Rerank the results using the specified reranker.
+    ///
+    /// This is currently only supported for hybrid search.
+    fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;
+
+    /// The method to normalize the scores. Can be "Rank" or "Score". If "Rank",
+    /// the scores are converted to ranks and then normalized. If "Score", the
+    /// scores are normalized directly.
+    fn norm(self, norm: NormalizeMethod) -> Self;
 }
 
 pub trait HasQuery {
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
         self.mut_query().with_row_id = true;
         self
     }
 
+    fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
+        self.mut_query().reranker = Some(reranker);
+        self
+    }
+
+    fn norm(mut self, norm: NormalizeMethod) -> Self {
+        self.mut_query().norm = Some(norm);
+        self
+    }
 }
 
 /// Options for controlling the execution of a query
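These builder methods compose with the rest of the query API, and combining full_text_search with nearest_to is what routes execution down the new hybrid path (see the ExecutableQuery change below). A minimal usage sketch, assuming an open table: Table with an FTS index on text and a two-dimensional vector column; the shape mirrors the new test_hybrid_search test, everything else is illustrative:

use std::sync::Arc;

use futures::TryStreamExt;
use lance_index::scalar::FullTextSearchQuery;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::rerankers::{rrf::RRFReranker, NormalizeMethod};

// rerank/norm are optional; execute_hybrid falls back to RRFReranker::default()
// and plain score normalization when they are left unset.
let batches = table
    .query()
    .full_text_search(FullTextSearchQuery::new("dog".to_string()))
    .limit(10)
    .nearest_to(&[0.0, 0.0])?
    .norm(NormalizeMethod::Rank)
    .rerank(Arc::new(RRFReranker::default()))
    .execute()
    .await?
    .try_collect::<Vec<_>>()
    .await?;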
@@ -600,6 +632,13 @@ pub struct Query {
 
     /// If set to false, the filter will be applied after the vector search.
     pub(crate) prefilter: bool,
 
+    /// Implementation of reranker that can be used to reorder or combine query
+    /// results, especially if using hybrid search
+    pub(crate) reranker: Option<Arc<dyn Reranker>>,
+
+    /// Configure how query results are normalized when doing hybrid search
+    pub(crate) norm: Option<NormalizeMethod>,
 }
 
 impl Query {
@@ -614,6 +653,8 @@ impl Query {
             fast_search: false,
             with_row_id: false,
             prefilter: true,
+            reranker: None,
+            norm: None,
         }
     }
 
@@ -862,6 +903,65 @@ impl VectorQuery {
         self.use_index = false;
         self
     }
 
+    pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
+        // clone the query and request row IDs, which may be needed for reranking
+        let fts_query = self.base.clone().with_row_id();
+        let mut vector_query = self.clone().with_row_id();
+
+        vector_query.base.full_text_search = None;
+        let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
+
+        let (fts_results, vec_results) = try_join!(
+            fts_results.try_collect::<Vec<_>>(),
+            vec_results.try_collect::<Vec<_>>()
+        )?;
+
+        // Get the schemas to use when combining batches. If either result set is
+        // empty, its schema is derived from the other one (see hybrid::query_schemas).
+        let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);
+
+        // concatenate all the batches together
+        let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
+        let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
+
+        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
+            vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
+            fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
+        }
+
+        vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
+        fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
+
+        let reranker = self
+            .base
+            .reranker
+            .clone()
+            .unwrap_or(Arc::new(RRFReranker::default()));
+
+        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
+            message: "there should be an FTS search".to_string(),
+        })?;
+
+        let mut results = reranker
+            .rerank_hybrid(&fts_query.query, vec_results, fts_results)
+            .await?;
+
+        check_reranker_result(&results)?;
+
+        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
+        if results.num_rows() > limit {
+            results = results.slice(0, limit);
+        }
+
+        if !self.base.with_row_id {
+            results = results.drop_column(ROW_ID)?;
+        }
+
+        Ok(SendableRecordBatchStream::from(
+            RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
+        ))
+    }
 }
 
 impl ExecutableQuery for VectorQuery {
@@ -873,6 +973,11 @@ impl ExecutableQuery for VectorQuery {
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        if self.base.full_text_search.is_some() {
+            let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
+            return Ok(hybrid_result);
+        }
+
         Ok(SendableRecordBatchStream::from(
             DatasetRecordBatchStream::new(execute_plan(
                 self.create_plan(options).await?,
@@ -894,20 +999,20 @@ impl HasQuery for VectorQuery {
 
 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
+    use std::{collections::HashSet, sync::Arc};
 
     use super::*;
-    use arrow::{compute::concat_batches, datatypes::Int32Type};
+    use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
     use arrow_array::{
-        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
-        RecordBatchReader,
+        cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
+        RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
     };
     use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
     use futures::{StreamExt, TryStreamExt};
     use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
     use tempfile::tempdir;
 
-    use crate::{connect, Table};
+    use crate::{connect, connection::CreateTableMode, Table};
 
     #[tokio::test]
     async fn test_setters_getters() {
@@ -1274,4 +1379,156 @@ mod tests {
         assert!(query_index.values().contains(&0));
         assert!(query_index.values().contains(&1));
     }
 
+    #[tokio::test]
+    async fn test_hybrid_search() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
+        let vectors = vec![
+            Some(vec![Some(0.0), Some(0.0)]),
+            Some(vec![Some(-2.0), Some(-2.0)]),
+            Some(vec![Some(50.0), Some(50.0)]),
+            Some(vec![Some(-30.0), Some(-30.0)]),
+        ];
+        let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);
+
+        let record_batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .execute()
+            .await
+            .unwrap();
+
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let batch = &results[0];
+
+        let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
+        let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
+        assert!(texts.contains("cat")); // should be close by vector search
+        assert!(texts.contains("b")); // should be close by fts search
+
+        // ensure that this works correctly if there are no matching FTS results
+        let fts_query = FullTextSearchQuery::new("z".to_string());
+        table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_hybrid_search_empty_table() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        // ensure hybrid search is also supported on a fully empty table
+        let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
+        let record_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(Vec::<&str>::new())),
+                Arc::new(
+                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
+                ),
+            ],
+        )
+        .unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .mode(CreateTableMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let batch = &results[0];
+        assert_eq!(0, batch.num_rows());
+        assert_eq!(2, batch.num_columns());
+    }
 }
rust/lancedb/src/query/hybrid.rs (new file, 346 lines)
@@ -0,0 +1,346 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow::compute::{
+    kernels::numeric::{div, sub},
+    max, min,
+};
+use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
+use arrow_schema::{DataType, Field, Schema, SortOptions};
+use lance::dataset::ROW_ID;
+use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
+use std::sync::Arc;
+
+use crate::error::{Error, Result};
+
+/// Converts the result's score column to a rank.
+///
+/// Expects the `column` argument to be of type Float32 and will panic if it's not.
+pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
+    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
+        message: format!(
+            "expected column {} not found in rank. found columns {:?}",
+            column,
+            results
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.name())
+                .collect::<Vec<_>>(),
+        ),
+    })?;
+
+    if results.num_rows() == 0 {
+        return Ok(results);
+    }
+
+    let scores: Float32Array = downcast_array(scores);
+    let ranks = Float32Array::from_iter_values(
+        arrow::compute::kernels::rank::rank(
+            &scores,
+            Some(SortOptions {
+                descending: !ascending.unwrap_or(true),
+                ..Default::default()
+            }),
+        )?
+        .iter()
+        .map(|i| *i as f32),
+    );
+
+    let schema = results.schema();
+    let (column_idx, _) = schema.column_with_name(column).unwrap();
+    let mut columns = results.columns().to_vec();
+    columns[column_idx] = Arc::new(ranks);
+
+    let results = RecordBatch::try_new(results.schema(), columns)?;
+
+    Ok(results)
+}
+
+/// Get the query schemas needed when combining the search results.
+///
+/// If either of the record batches is empty, then we create a schema from the
+/// other record batch and replace the score/distance column. If both record
+/// batches are empty, create empty schemas.
+pub fn query_schemas(
+    fts_results: &[RecordBatch],
+    vec_results: &[RecordBatch],
+) -> (Arc<Schema>, Arc<Schema>) {
+    let (fts_schema, vec_schema) = match (
+        fts_results.first().map(|r| r.schema()),
+        vec_results.first().map(|r| r.schema()),
+    ) {
+        (Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
+        (None, Some(vec_schema)) => {
+            let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
+            (Arc::new(fts_schema), vec_schema)
+        }
+        (Some(fts_schema), None) => {
+            let vec_schema = with_field_name_replaced(&fts_schema, SCORE_COL, DIST_COL);
+            (fts_schema, Arc::new(vec_schema))
+        }
+        (None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
+    };
+
+    (fts_schema, vec_schema)
+}
+
+pub fn empty_fts_schema() -> Schema {
+    Schema::new(vec![
+        Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
+        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
+    ])
+}
+
+pub fn empty_vec_schema() -> Schema {
+    Schema::new(vec![
+        Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
+        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
+    ])
+}
+
+pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
+    let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
+        if field.name() == target {
+            Some(i)
+        } else {
+            None
+        }
+    });
+
+    let mut fields = schema.fields().to_vec();
+    if let Some(idx) = field_idx {
+        let new_field = (*fields[idx]).clone().with_name(replacement);
+        fields[idx] = Arc::new(new_field);
+    }
+
+    Schema::new(fields)
+}
+
+/// Normalize the scores column to have values between 0 and 1.
+///
+/// Expects the `column` argument to be of type Float32 and will panic if it's not.
+pub fn normalize_scores(
+    results: RecordBatch,
+    column: &str,
+    invert: Option<bool>,
+) -> Result<RecordBatch> {
+    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
+        message: format!(
+            "expected column {} not found in normalize_scores. found columns {:?}",
+            column,
+            results
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.name())
+                .collect::<Vec<_>>(),
+        ),
+    })?;
+
+    if results.num_rows() == 0 {
+        return Ok(results);
+    }
+    let mut scores: Float32Array = downcast_array(scores);
+
+    let max = max(&scores).unwrap_or(0.0);
+    let min = min(&scores).unwrap_or(0.0);
+
+    // this is equivalent to np.isclose, which is used in the Python implementation
+    let rng = if max - min < 10e-5 { max } else { max - min };
+
+    // if rng is 0, then min and max are both 0, so we just leave the scores as is
+    if rng != 0.0 {
+        let tmp = div(
+            &sub(&scores, &Float32Array::new_scalar(min))?,
+            &Float32Array::new_scalar(rng),
+        )?;
+        scores = downcast_array(&tmp);
+    }
+
+    if invert.unwrap_or(false) {
+        let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
+        scores = downcast_array(&tmp);
+    }
+
+    let schema = results.schema();
+    let (column_idx, _) = schema.column_with_name(column).unwrap();
+    let mut columns = results.columns().to_vec();
+    columns[column_idx] = Arc::new(scores);
+
+    let results = RecordBatch::try_new(results.schema(), columns)?;
+
+    Ok(results)
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow_array::StringArray;
+    use arrow_schema::{DataType, Field, Schema};
+
+    #[test]
+    fn test_rank() {
+        let schema = Arc::new(Schema::new(vec![
+            Arc::new(Field::new("name", DataType::Utf8, false)),
+            Arc::new(Field::new("score", DataType::Float32, false)),
+        ]));
+
+        let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]);
+        let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]);
+
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap();
+
+        let result = rank(batch.clone(), "score", Some(false)).unwrap();
+        assert_eq!(2, result.schema().fields().len());
+        assert_eq!("name", result.schema().field(0).name());
+        assert_eq!("score", result.schema().field(1).name());
+
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![4.0, 3.0, 5.0, 1.0, 2.0]
+        );
+
+        // check sort ascending
+        let result = rank(batch.clone(), "score", Some(true)).unwrap();
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![2.0, 3.0, 1.0, 5.0, 4.0]
+        );
+
+        // ensure default sort is ascending
+        let result = rank(batch.clone(), "score", None).unwrap();
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![2.0, 3.0, 1.0, 5.0, 4.0]
+        );
+
+        // check it can handle an empty batch
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(Vec::<&str>::new())),
+                Arc::new(Float32Array::from(Vec::<f32>::new())),
+            ],
+        )
+        .unwrap();
+        let result = rank(batch.clone(), "score", None).unwrap();
+        assert_eq!(0, result.num_rows());
+        assert_eq!(2, result.schema().fields().len());
+        assert_eq!("name", result.schema().field(0).name());
+        assert_eq!("score", result.schema().field(1).name());
+
+        // check it returns the expected error when there's no column
+        let result = rank(batch.clone(), "bad_col", None);
+        match result {
+            Err(Error::InvalidInput { message }) => {
+                assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
+            }
+            _ => {
+                panic!("expected invalid input error, received {:?}", result)
+            }
+        }
+    }
+
+    #[test]
+    fn test_normalize_scores() {
+        let schema = Arc::new(Schema::new(vec![
+            Arc::new(Field::new("name", DataType::Utf8, false)),
+            Arc::new(Field::new("score", DataType::Float32, false)),
+        ]));
+
+        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
+        let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));
+
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+
+        let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.6, 0.4, 0.7, 1.0]
+        );
+
+        // check it can invert the normalization
+        let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
+        );
+
+        // check that the default is not inverted
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.6, 0.4, 0.7, 1.0]
+        );
+
+        // check that it will function correctly if all the values are the same
+        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
+        let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.0, 0.0, 0.0, 0.0]
+        );
+
+        // check it keeps floating point rounding errors for same score normalized the same
+        // e.g., the behaviour is consistent with python
+        let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![
+                1.0 - 0.9999999,
+                1.0 - 0.9999999,
+                1.0 - 0.9999999,
+                1.0 - 0.9999999,
+                0.0
+            ]
+        );
+
+        // check that it can handle if all the scores are 0
+        let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.0, 0.0, 0.0, 0.0]
+        );
+    }
+}
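A note on the arithmetic above, since these two helpers define the scale on which hybrid results are compared. With invert unset, normalize_scores applies plain min-max scaling

    s'_i = (s_i - s_min) / (s_max - s_min)

with two special cases taken from the code: when s_max - s_min < 10e-5 the denominator is replaced by s_max (mirroring np.isclose in the Python implementation), and a batch whose scores are all zero is returned unchanged. rank replaces each score with its 1-based position, ascending by default (Some(false) gives descending), and execute_hybrid then feeds those ranks through the same min-max scaling when NormalizeMethod::Rank is selected.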
@@ -563,6 +563,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         let (index_type, distance_type) = match index.index {
             // TODO: Should we pass the actual index parameters? SaaS does not
             // yet support them.
+            Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)),
             Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
             Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
             Index::BTree(_) => ("BTREE", None),
@@ -873,6 +874,7 @@ mod tests {
     use lance_index::scalar::FullTextSearchQuery;
     use reqwest::Body;
 
+    use crate::index::vector::IvfFlatIndexBuilder;
     use crate::{
         index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
         query::{ExecutableQuery, QueryBase},
@@ -1489,6 +1491,11 @@ mod tests {
     #[tokio::test]
     async fn test_create_index() {
         let cases = [
+            (
+                "IVF_FLAT",
+                Some("hamming"),
+                Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
+            ),
             ("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
             (
                 "IVF_PQ",
rust/lancedb/src/rerankers.rs (new file, 87 lines)
@@ -0,0 +1,87 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::collections::BTreeSet;
+
+use arrow::{
+    array::downcast_array,
+    compute::{concat_batches, filter_record_batch},
+};
+use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
+use async_trait::async_trait;
+use lance::dataset::ROW_ID;
+
+use crate::error::{Error, Result};
+
+pub mod rrf;
+
+/// column name for the reranker relevance score
+const RELEVANCE_SCORE: &str = "_relevance_score";
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum NormalizeMethod {
+    Score,
+    Rank,
+}
+
+/// Interface for a reranker. A reranker is used to rerank the results from a
+/// vector and FTS search. This is useful for combining the results from both
+/// search methods.
+#[async_trait]
+pub trait Reranker: std::fmt::Debug + Sync + Send {
+    // TODO: support vector reranking and FTS reranking. Currently only hybrid reranking is supported.
+
+    /// The rerank function receives the individual result sets from the vector and
+    /// FTS searches. You can choose to use any of the results to generate the final
+    /// results, allowing maximum flexibility.
+    async fn rerank_hybrid(
+        &self,
+        query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> Result<RecordBatch>;
+
+    fn merge_results(
+        &self,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> Result<RecordBatch> {
+        let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;
+
+        // keep only the first occurrence of each row id
+        let mut mask = BooleanArray::builder(combined.num_rows());
+        let mut unique_ids = BTreeSet::new();
+        let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
+            message: format!(
+                "could not find expected column {} while merging results. found columns {:?}",
+                ROW_ID,
+                combined
+                    .schema()
+                    .fields()
+                    .iter()
+                    .map(|f| f.name())
+                    .collect::<Vec<_>>()
+            ),
+        })?;
+        let row_ids: UInt64Array = downcast_array(row_ids);
+        row_ids.values().iter().for_each(|id| {
+            mask.append_value(unique_ids.insert(id));
+        });
+
+        let combined = filter_record_batch(&combined, &mask.finish())?;
+
+        Ok(combined)
+    }
+}
+
+pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
+    if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
+        return Err(Error::Schema {
+            message: format!(
+                "rerank_hybrid must return a RecordBatch with a column named {}",
+                RELEVANCE_SCORE
+            ),
+        });
+    }
+
+    Ok(())
+}
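check_reranker_result above is the only contract a custom implementation has to satisfy: rerank_hybrid must return a batch containing a _relevance_score column. A minimal sketch of a custom reranker, assuming it is written inside this crate (the RELEVANCE_SCORE constant is crate-private); it reuses the default merge_results to deduplicate by _rowid and gives every row the same score, so it is illustrative rather than useful:

use std::sync::Arc;

use arrow_array::{Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;

use crate::error::Result;
use crate::rerankers::{Reranker, RELEVANCE_SCORE};

#[derive(Debug)]
struct ConstantReranker;

#[async_trait]
impl Reranker for ConstantReranker {
    async fn rerank_hybrid(
        &self,
        _query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        // merge_results concatenates both batches and drops duplicate _rowid entries
        let merged = self.merge_results(vector_results, fts_results)?;

        // append a constant _relevance_score column so check_reranker_result passes
        let scores = Float32Array::from(vec![1.0_f32; merged.num_rows()]);
        let mut fields = merged.schema().fields().to_vec();
        fields.push(Arc::new(Field::new(RELEVANCE_SCORE, DataType::Float32, false)));
        let mut columns = merged.columns().to_vec();
        columns.push(Arc::new(scores));
        Ok(RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)?)
    }
}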
rust/lancedb/src/rerankers/rrf.rs (new file, 223 lines)
@@ -0,0 +1,223 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::collections::BTreeMap;
+use std::sync::Arc;
+
+use arrow::{
+    array::downcast_array,
+    compute::{sort_to_indices, take},
+};
+use arrow_array::{Float32Array, RecordBatch, UInt64Array};
+use arrow_schema::{DataType, Field, Schema, SortOptions};
+use async_trait::async_trait;
+use lance::dataset::ROW_ID;
+
+use crate::error::{Error, Result};
+use crate::rerankers::{Reranker, RELEVANCE_SCORE};
+
+/// Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm based
+/// on the ranks of the vector and FTS search results.
+#[derive(Debug)]
+pub struct RRFReranker {
+    k: f32,
+}
+
+impl RRFReranker {
+    /// Create a new RRFReranker
+    ///
+    /// The parameter k is a constant used in the RRF formula (default is 60).
+    /// Experiments indicate that k = 60 was near-optimal, but that the choice
+    /// is not critical. See paper:
+    /// https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
+    pub fn new(k: f32) -> Self {
+        Self { k }
+    }
+}
+
+impl Default for RRFReranker {
+    fn default() -> Self {
+        Self { k: 60.0 }
+    }
+}
+
+#[async_trait]
+impl Reranker for RRFReranker {
+    async fn rerank_hybrid(
+        &self,
+        _query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> Result<RecordBatch> {
+        let vector_ids = vector_results
+            .column_by_name(ROW_ID)
+            .ok_or(Error::InvalidInput {
+                message: format!(
+                    "expected column {} not found in vector_results. found columns {:?}",
+                    ROW_ID,
+                    vector_results
+                        .schema()
+                        .fields()
+                        .iter()
+                        .map(|f| f.name())
+                        .collect::<Vec<_>>()
+                ),
+            })?;
+        let fts_ids = fts_results
+            .column_by_name(ROW_ID)
+            .ok_or(Error::InvalidInput {
+                message: format!(
+                    "expected column {} not found in fts_results. found columns {:?}",
+                    ROW_ID,
+                    fts_results
+                        .schema()
+                        .fields()
+                        .iter()
+                        .map(|f| f.name())
+                        .collect::<Vec<_>>()
+                ),
+            })?;
+
+        let vector_ids: UInt64Array = downcast_array(&vector_ids);
+        let fts_ids: UInt64Array = downcast_array(&fts_ids);
+
+        // accumulate 1 / (position + k) for every list in which a row id appears
+        let mut rrf_score_map = BTreeMap::new();
+        let mut update_score_map = |(i, result_id)| {
+            let score = 1.0 / (i as f32 + self.k);
+            rrf_score_map
+                .entry(result_id)
+                .and_modify(|e| *e += score)
+                .or_insert(score);
+        };
+        vector_ids
+            .values()
+            .iter()
+            .enumerate()
+            .for_each(&mut update_score_map);
+        fts_ids
+            .values()
+            .iter()
+            .enumerate()
+            .for_each(&mut update_score_map);
+
+        let combined_results = self.merge_results(vector_results, fts_results)?;
+
+        let combined_row_ids: UInt64Array =
+            downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
+        let relevance_scores = Float32Array::from_iter_values(
+            combined_row_ids
+                .values()
+                .iter()
+                .map(|row_id| rrf_score_map.get(row_id).unwrap())
+                .copied(),
+        );
+
+        // keep track of indices sorted by the relevance column
+        let sort_indices = sort_to_indices(
+            &relevance_scores,
+            Some(SortOptions {
+                descending: true,
+                ..Default::default()
+            }),
+            None,
+        )
+        .unwrap();
+
+        // add relevance scores to columns
+        let mut columns = combined_results.columns().to_vec();
+        columns.push(Arc::new(relevance_scores));
+
+        // sort by the relevance scores
+        let columns = columns
+            .iter()
+            .map(|c| take(c, &sort_indices, None).unwrap())
+            .collect();
+
+        // add relevance score to schema
+        let mut fields = combined_results.schema().fields().to_vec();
+        fields.push(Arc::new(Field::new(
+            RELEVANCE_SCORE,
+            DataType::Float32,
+            false,
+        )));
+        let schema = Schema::new(fields);
+
+        let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;
+
+        Ok(combined_results)
+    }
+}
+
+#[cfg(test)]
+pub mod test {
+    use super::*;
+    use arrow_array::StringArray;
+
+    #[tokio::test]
+    async fn test_rrf_reranker() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new(ROW_ID, DataType::UInt64, false),
+        ]));
+
+        let vec_results = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"])),
+                Arc::new(UInt64Array::from(vec![1, 4, 2, 5, 3])),
+            ],
+        )
+        .unwrap();
+
+        let fts_results = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(vec!["bar", "bean", "dog"])),
+                Arc::new(UInt64Array::from(vec![4, 5, 3])),
+            ],
+        )
+        .unwrap();
+
+        // with k = 1, the scores should be calculated as:
+        // - foo = 1/1 = 1.0
+        // - bar = 1/2 + 1/1 = 1.5
+        // - baz = 1/3 = 0.333
+        // - bean = 1/4 + 1/2 = 0.75
+        // - dog = 1/5 + 1/3 = 0.533
+        // then we should get the results ranked in descending order
+
+        let reranker = RRFReranker::new(1.0);
+
+        let result = reranker
+            .rerank_hybrid("", vec_results, fts_results)
+            .await
+            .unwrap();
+
+        assert_eq!(3, result.schema().fields().len());
+        assert_eq!("name", result.schema().fields().first().unwrap().name());
+        assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name());
+        assert_eq!(
+            RELEVANCE_SCORE,
+            result.schema().fields().get(2).unwrap().name()
+        );
+
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["bar", "foo", "bean", "dog", "baz"]
+        );
+
+        let ids: UInt64Array = downcast_array(result.column(1));
+        assert_eq!(
+            ids.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![4, 1, 5, 3, 2]
+        );
+
+        let scores: Float32Array = downcast_array(result.column(2));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0]
+        );
+    }
+}
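For reference, the fused score computed above can be written out explicitly; this restates the update_score_map loop, where i_r(d) is the zero-based position of row d in result list r and k is the constructor parameter (60 by default):

    score(d) = sum over lists r containing d of 1 / (k + i_r(d))

With k = 1 this reproduces the expected values in test_rrf_reranker: bar = 1/1 + 1/2 = 1.5, foo = 1/1 = 1.0, bean = 1/2 + 1/4 = 0.75, dog = 1/3 + 1/5 ≈ 0.533, baz = 1/3 ≈ 0.333. Note the zero-based indexing: the conventional RRF term 1/(k + rank) corresponds to rank = i + 1 here.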