Compare commits


21 Commits

Author SHA1 Message Date
Lance Release
a27c5cf12b Bump version: 0.17.2-beta.1 → 0.17.2-beta.2 2025-01-06 05:34:27 +00:00
BubbleCal
f4dea72cc5 feat: support vector search with distance thresholds (#1993)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-01-06 13:23:39 +08:00
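
A minimal sketch of how the new distance threshold could be used. The Python side of this changeset adds a `distance_range(lower_bound, upper_bound)` builder (see the `query.py` diff further down); the `distanceRange` call below is the assumed TypeScript counterpart and is not confirmed by this diff:

```ts
import * as lancedb from "@lancedb/lancedb";

const db = await lancedb.connect("/tmp/lancedb");
const tbl = await db.openTable("my_table");

// Keep only rows whose distance to the query vector lies in [0.0, 0.5).
const rows = await tbl
  .query()
  .nearestTo([0.1, 0.1])
  .distanceRange(0.0, 0.5) // assumed TS counterpart of Python's distance_range
  .limit(10)
  .toArray();
```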
Lei Xu
f76c4a5ce1 chore: add pyright static type checking and fix some of the table interface (#1996)
* Enable `pyright` in the project
* Fixed some pyright typing errors in `table.py`
2025-01-04 15:24:58 -08:00
ahaapple
164ce397c2 docs: fix full-text search (Native FTS) TypeScript doc error (#1992)
Fix

```
Cannot find name 'queryType'.ts(2304)
any
```
2025-01-03 13:36:10 -05:00
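
The corrected snippet (shown in the docs diff further down) passes the query type as a plain string argument rather than the Python-style `queryType=` keyword syntax, which does not compile in TypeScript:

```ts
import * as lancedb from "@lancedb/lancedb";

const db = await lancedb.connect("/tmp/lancedb");
const tbl = await db.openTable("my_table");

// The query type is simply the second positional argument.
const results = await tbl
  .search("puppy", "fts")
  .select(["text"])
  .limit(10)
  .toArray();
```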
BubbleCal
445a312667 fix: selecting columns failed on FTS and hybrid search (#1991)
It reported the error `AttributeError: 'builtins.FTSQuery' object has no
attribute 'select_columns'` because the `select_columns` method was missing
from the Rust bindings.

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-01-03 13:08:12 +08:00
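
For illustration, this is the shape of query that was failing: a full-text search combined with a column projection, sketched here against the TypeScript SDK (the fix itself landed in the Python Rust bindings):

```ts
import * as lancedb from "@lancedb/lancedb";

const db = await lancedb.connect("/tmp/lancedb");
const table = await db.openTable("my_table");

// Full-text search that narrows the returned columns; the Python async
// equivalent of this raised the AttributeError before the fix.
const hits = await table
  .query()
  .fullTextSearch("puppy")
  .select(["text"])
  .limit(10)
  .toArray();
```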
Lance Release
92d845fa72 Bump version: 0.17.2-beta.0 → 0.17.2-beta.1 2024-12-31 23:36:18 +00:00
Lei Xu
397813f6a4 chore: bump pylance to 0.21.1b1 (#1989) 2024-12-31 15:34:27 -08:00
Lei Xu
50c30c5d34 chore(python): fix typo of the synchronized checkout API (#1988) 2024-12-30 18:54:31 -08:00
Bert
c9f248b058 feat: add hybrid search to node and rust SDKs (#1940)
Support hybrid search in both rust and node SDKs.

- Adds a new rerankers package to rust LanceDB, with the implementation
of the default RRF reranker
- Adds a new hybrid package to lancedb, with helper methods related to
hybrid search, such as normalizing scores and converting the score column
to rank columns
- Adds the capability for LanceDB VectorQuery to perform hybrid search when
it has both nearest-vector and full-text search parameters. A usage sketch
follows below.
- Adds wrappers for reranker implementations to the nodejs SDK.

Additional rerankers will be added in follow-up PRs.

https://github.com/lancedb/lancedb/issues/1921
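
A usage sketch distilled from the tests added in this PR (see the diffs further down): a query that sets both a vector target and an FTS term takes the hybrid path, and `rerank` fuses the two result lists, here with the built-in Rust-backed RRF reranker exposed through the new `rerankers` namespace:

```ts
import * as lancedb from "@lancedb/lancedb";

const db = await lancedb.connect("/tmp/lancedb");
const table = await db.openTable("my_table");

// nearestTo + fullTextSearch together trigger hybrid search; RRFReranker
// is the default reranker implementation added by this PR.
const results = await table
  .query()
  .nearestTo([0.1, 0.1])
  .fullTextSearch("dog")
  .rerank(await lancedb.rerankers.RRFReranker.create())
  .select(["text"])
  .limit(5)
  .toArray();
```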

---
Notes about how the rust rerankers are wrapped for calling from JS:

I wanted to keep the core reranker logic, and the invocation of the
reranker by the query code, in Rust. This aligns with the philosophy of
the new node SDK where it's just a thin wrapper around Rust. However, I
also wanted to have support for users who want to add custom rerankers
written in Javascript.

When we add a reranker to the query from Javascript, it adds a special
Rust reranker that has a callback to the Javascript code (which could
then turn around and call an underlying Rust reranker implementation if
desired). This adds a bit of complexity, but overall I think it moves us
in the right direction of having the majority of the query logic in the
underlying Rust SDK while keeping the option open to support custom
Javascript Rerankers.
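
Concretely, a custom Javascript reranker only has to implement `rerankHybrid` over Arrow `RecordBatch`es. A trimmed-down version of the test added in this PR (the pass-through body is purely illustrative):

```ts
import { RecordBatch } from "apache-arrow";
import * as lancedb from "@lancedb/lancedb";

// Receives the vector and FTS result sets as Arrow RecordBatches and must
// return a single reranked RecordBatch.
class PassThroughReranker {
  async rerankHybrid(
    _query: string,
    vecResults: RecordBatch,
    _ftsResults: RecordBatch,
  ): Promise<RecordBatch> {
    return vecResults; // a real reranker would merge and re-score here
  }
}

const db = await lancedb.connect("/tmp/lancedb");
const table = await db.openTable("my_table");

const reranked = await table
  .query()
  .nearestTo([0.1, 0.1])
  .fullTextSearch("dog")
  .rerank(new PassThroughReranker())
  .limit(5)
  .toArray();
```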
2024-12-30 09:03:41 -05:00
Renato Marroquin
0cb6da6b7e docs: add new indexes to python docs (#1945)
closes issue #1855

Co-authored-by: Renato Marroquin <renato.marroquin@oracle.com>
2024-12-28 15:35:10 -08:00
BubbleCal
aec8332eb5 chore: add dynamic = ["version"] to pass build check (#1977)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-12-28 10:45:23 -08:00
Lance Release
46061070e6 Updating package-lock.json 2024-12-26 07:40:12 +00:00
Lance Release
dae8334d0b Bump version: 0.17.1 → 0.17.2-beta.0 2024-12-25 08:28:59 +00:00
BubbleCal
8c81968b59 feat: support IVF_FLAT on remote table in rust (#1979)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-12-25 15:54:17 +08:00
BubbleCal
16cf2990f3 feat: create IVF_FLAT on remote table (#1978)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-12-25 14:57:07 +08:00
Will Jones
0a0f667bbd chore: fix typos (#1976) 2024-12-24 12:50:54 -08:00
Will Jones
03753fd84b ci(node): remove hardcoded toolchain from typescript release build (#1974)
We upgraded the toolchain in #1960, but didn't realize it was also hardcoded
in `npm-publish.yml`. I found that if I just removed the hard-coded
toolchain, it selects the correct one.

This didn't fully fix Windows Arm, so I created a follow-up issue here:
https://github.com/lancedb/lancedb/issues/1975
2024-12-24 12:48:41 -08:00
Lance Release
55cceaa309 Updating package-lock.json 2024-12-24 18:39:00 +00:00
Lance Release
c3797eb834 Updating package-lock.json 2024-12-24 18:38:44 +00:00
Lance Release
c0d0f38494 Bump version: 0.14.1-beta.7 → 0.14.1 2024-12-24 18:38:11 +00:00
Lance Release
6a8ab78d0a Bump version: 0.14.1-beta.6 → 0.14.1-beta.7 2024-12-24 18:38:06 +00:00
55 changed files with 1879 additions and 263 deletions

View File

@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.14.1-beta.6"
+current_version = "0.14.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -159,7 +159,7 @@ jobs:
       - name: Install common dependencies
         run: |
           apk add protobuf-dev curl clang mold grep npm bash
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
           echo "source $HOME/.cargo/env" >> saved_env
           echo "export CC=clang" >> saved_env
           echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -167,7 +167,7 @@ jobs:
         if: ${{ matrix.config.arch == 'aarch64' }}
         run: |
           source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
           crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
           sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
           apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -262,7 +262,7 @@ jobs:
       - name: Install common dependencies
         run: |
           apk add protobuf-dev curl clang mold grep npm bash openssl-dev openssl-libs-static
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
+          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
           echo "source $HOME/.cargo/env" >> saved_env
           echo "export CC=clang" >> saved_env
           echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -272,7 +272,7 @@ jobs:
         if: ${{ matrix.config.arch == 'aarch64' }}
         run: |
           source "$HOME/.cargo/env"
-          rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
+          rustup target add aarch64-unknown-linux-musl
           crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
           sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
           apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -334,50 +334,51 @@ jobs:
         path: |
           node/dist/lancedb-vectordb-win32*.tgz
-  node-windows-arm64:
-    name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: node-native-windows-${{ matrix.config.arch }}
-          path: |
-            node/dist/lancedb-vectordb-win32*.tgz
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # node-windows-arm64:
+  #   name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: node-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           node/dist/lancedb-vectordb-win32*.tgz
 
   nodejs-windows:
     name: lancedb ${{ matrix.target }}
@@ -413,57 +414,58 @@ jobs:
         path: |
           nodejs/dist/*.node
-  nodejs-windows-arm64:
-    name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
-    # Only runs on tags that matches the make-release action
-    if: startsWith(github.ref, 'refs/tags/v')
-    runs-on: ubuntu-latest
-    container: alpine:edge
-    strategy:
-      fail-fast: false
-      matrix:
-        config:
-          # - arch: x86_64
-          - arch: aarch64
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install dependencies
-        run: |
-          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-          curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
-          echo "source $HOME/.cargo/env" >> saved_env
-          echo "export CC=clang" >> saved_env
-          echo "export AR=llvm-ar" >> saved_env
-          source "$HOME/.cargo/env"
-          rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
-          (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
-          echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
-          echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
-          printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
-          chmod u+x $HOME/.cargo/bin/cargo-xwin
-      - name: Configure x86_64 build
-        if: ${{ matrix.config.arch == 'x86_64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
-      - name: Configure aarch64 build
-        if: ${{ matrix.config.arch == 'aarch64' }}
-        run: |
-          echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
-      - name: Build Windows Artifacts
-        run: |
-          source ./saved_env
-          bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
-      - name: Upload Windows Artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: nodejs-native-windows-${{ matrix.config.arch }}
-          path: |
-            nodejs/dist/*.node
+  # TODO: https://github.com/lancedb/lancedb/issues/1975
+  # nodejs-windows-arm64:
+  #   name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
+  #   # Only runs on tags that matches the make-release action
+  #   # if: startsWith(github.ref, 'refs/tags/v')
+  #   runs-on: ubuntu-latest
+  #   container: alpine:edge
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       config:
+  #         # - arch: x86_64
+  #         - arch: aarch64
+  #   steps:
+  #     - name: Checkout
+  #       uses: actions/checkout@v4
+  #     - name: Install dependencies
+  #       run: |
+  #         apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
+  #         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
+  #         echo "source $HOME/.cargo/env" >> saved_env
+  #         echo "export CC=clang" >> saved_env
+  #         echo "export AR=llvm-ar" >> saved_env
+  #         source "$HOME/.cargo/env"
+  #         rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
+  #         (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
+  #         echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
+  #         echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
+  #         printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
+  #         chmod u+x $HOME/.cargo/bin/cargo-xwin
+  #     - name: Configure x86_64 build
+  #       if: ${{ matrix.config.arch == 'x86_64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
+  #     - name: Configure aarch64 build
+  #       if: ${{ matrix.config.arch == 'aarch64' }}
+  #       run: |
+  #         echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
+  #     - name: Build Windows Artifacts
+  #       run: |
+  #         source ./saved_env
+  #         bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
+  #     - name: Upload Windows Artifacts
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: nodejs-native-windows-${{ matrix.config.arch }}
+  #         path: |
+  #           nodejs/dist/*.node
 
   release:
     name: vectordb NPM Publish
-    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows, node-windows-arm64]
+    needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows]
     runs-on: ubuntu-latest
     # Only runs on tags that matches the make-release action
     if: startsWith(github.ref, 'refs/tags/v')
@@ -503,7 +505,7 @@ jobs:
   release-nodejs:
     name: lancedb NPM Publish
-    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows, nodejs-windows-arm64]
+    needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows]
     runs-on: ubuntu-latest
     # Only runs on tags that matches the make-release action
    if: startsWith(github.ref, 'refs/tags/v')

View File

@@ -30,10 +30,10 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.12"
       - name: Install ruff
         run: |
-          pip install ruff==0.5.4
+          pip install ruff==0.8.4
       - name: Format check
         run: ruff format --check .
       - name: Lint

View File

@@ -21,16 +21,16 @@ categories = ["database-implementations"]
rust-version = "1.78.0" rust-version = "1.78.0"
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.21.0", "features" = [ lance = { "version" = "=0.21.1", "features" = [
"dynamodb", "dynamodb",
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } ], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-io = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-index = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-linalg = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-table = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-testing = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-datafusion = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-encoding = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.5" } lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "53.2", optional = false } arrow = { version = "53.2", optional = false }
arrow-array = "53.2" arrow-array = "53.2"

View File

@@ -50,7 +50,7 @@ Consider that we have a LanceDB table named `my_table`, whose string column `tex
 });
 
 await tbl
-  .search("puppy", queryType="fts")
+  .search("puppy", "fts")
   .select(["text"])
   .limit(10)
   .toArray();

View File

@@ -133,6 +133,10 @@ lists the indices that LanceDb supports.
 ::: lancedb.index.IvfPq
 
+::: lancedb.index.HnswPq
+
+::: lancedb.index.HnswSq
+
 ::: lancedb.index.IvfFlat
 
 ## Querying (Asynchronous)

View File

@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.14.1-beta.6</version>
+    <version>0.14.1-final.0</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

View File

@@ -6,7 +6,7 @@
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.14.1-beta.6</version>
+  <version>0.14.1-final.0</version>
   <packaging>pom</packaging>
   <name>LanceDB Parent</name>

View File: node/package-lock.json (generated)

@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.14.1-beta.6",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"
@@ -52,14 +52,14 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.6",
-        "@lancedb/vectordb-darwin-x64": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.6",
-        "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.6",
-        "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.6",
-        "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.6"
+        "@lancedb/vectordb-darwin-arm64": "0.14.1",
+        "@lancedb/vectordb-darwin-x64": "0.14.1",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+        "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+        "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+        "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+        "@lancedb/vectordb-win32-arm64-msvc": "0.14.1",
+        "@lancedb/vectordb-win32-x64-msvc": "0.14.1"
       },
       "peerDependencies": {
         "@apache-arrow/ts": "^14.0.2",
@@ -329,6 +329,97 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
+    "node_modules/@lancedb/vectordb-darwin-arm64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.14.1.tgz",
+      "integrity": "sha512-6t7XHR7dBjDmAS/kz5wbe7LPhKW+WkFA16ZPyh0lmuxfnss4VvN3LE6qQBHjzYzB9U6Nu/4ktQ50xZGEPTnc5A==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-darwin-x64": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.14.1.tgz",
+      "integrity": "sha512-8q6Kd6XnNPKN8wqj75pHVQ4KFl6z9BaI6lWDiEaCNcO3bjPZkcLFNosJq4raxZ9iUi50Yl0qFJ6qR0XFVTwnnw==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.14.1.tgz",
+      "integrity": "sha512-4djEMmeNb+p6nW/C4xb8wdMwnIbWfO8fYAwiplOxzxeOpPaUC9rhwUUDCbrJDCpMa8RP5ED4/jC6yT8epaDMDw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.14.1.tgz",
+      "integrity": "sha512-c33hSsp16pnC58plzx1OXuifp9Rachx/MshE/L/OReoutt74fFdrRJwUjE4UCAysyY5QdvTrNm9OhDjopQK2Bw==",
+      "cpu": [
+        "arm64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.14.1.tgz",
+      "integrity": "sha512-psu6cH9iLiSbUEZD1EWbOA4THGYSwJvS2XICO9yN7A6D41AP/ynYMRZNKWo1fpdi2Fjb0xNQwiNhQyqwbi5gzA==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-musl": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.14.1.tgz",
+      "integrity": "sha512-Rg4VWW80HaTFmR7EvNSu+nfRQQM8beO/otBn/Nus5mj5zFw/7cacGRmiEYhDnk5iAn8nauV+Jsi9j2U+C2hp5w==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.14.1.tgz",
+      "integrity": "sha512-XbifasmMbQIt3V9P0AtQND6M3XFiIAc1ZIgmjzBjOmxwqw4sQUwHMyJGIGOzKFZTK3fPJIGRHId7jAzXuBgfQg==",
+      "cpu": [
+        "x64"
+      ],
+      "license": "Apache-2.0",
+      "optional": true,
+      "os": [
+        "win32"
+      ]
+    },
     "node_modules/@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",

View File

@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "description": " Serverless, low-latency vector database for AI applications",
   "private": false,
   "main": "dist/index.js",
@@ -92,13 +92,13 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-x64": "0.14.1-beta.6",
-    "@lancedb/vectordb-darwin-arm64": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.6",
-    "@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.6",
-    "@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.6",
-    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.6"
+    "@lancedb/vectordb-darwin-x64": "0.14.1",
+    "@lancedb/vectordb-darwin-arm64": "0.14.1",
+    "@lancedb/vectordb-linux-x64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
+    "@lancedb/vectordb-linux-x64-musl": "0.14.1",
+    "@lancedb/vectordb-linux-arm64-musl": "0.14.1",
+    "@lancedb/vectordb-win32-x64-msvc": "0.14.1",
+    "@lancedb/vectordb-win32-arm64-msvc": "0.14.1"
   }
 }

View File

@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.14.1-beta.6"
+version = "0.14.1"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
@@ -12,7 +12,10 @@ categories.workspace = true
 crate-type = ["cdylib"]
 
 [dependencies]
+async-trait.workspace = true
 arrow-ipc.workspace = true
+arrow-array.workspace = true
+arrow-schema.workspace = true
 env_logger.workspace = true
 futures.workspace = true
 lancedb = { path = "../rust/lancedb", features = ["remote"] }

View File

@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";
 import {
   convertToTable,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   fromTableToBuffer,
   makeArrowTable,
   makeEmptyTable,
@@ -553,5 +555,28 @@
       });
     });
   });
+
+  describe("converting record batches to buffers", function () {
+    it("can convert to buffered record batch and back again", async function () {
+      const records = [
+        { text: "dog", vector: [0.1, 0.2] },
+        { text: "cat", vector: [0.3, 0.4] },
+      ];
+      const table = await convertToTable(records);
+      const batch = table.batches[0];
+      const buffer = await fromRecordBatchToBuffer(batch);
+      const result = await fromBufferToRecordBatch(buffer);
+      expect(JSON.stringify(batch.toArray())).toEqual(
+        JSON.stringify(result?.toArray()),
+      );
+    });
+
+    it("converting from buffer returns null if buffer has no record batches", async function () {
+      const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
+      expect(result).toEqual(null);
+    });
+  });
 },
 );

View File

@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import * as tmp from "tmp";
+
+import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
+import { RRFReranker } from "../lancedb/rerankers";
+
+describe("rerankers", function () {
+  let tmpDir: tmp.DirResult;
+  let conn: Connection;
+  let table: Table;
+
+  beforeEach(async () => {
+    tmpDir = tmp.dirSync({ unsafeCleanup: true });
+    conn = await connect(tmpDir.name);
+
+    table = await conn.createTable("mytable", [
+      { vector: [0.1, 0.1], text: "dog" },
+      { vector: [0.2, 0.2], text: "cat" },
+    ]);
+
+    await table.createIndex("text", {
+      config: Index.fts(),
+      replace: true,
+    });
+  });
+
+  it("will query with the custom reranker", async function () {
+    const expectedResult = [
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ];
+
+    class MyCustomReranker {
+      async rerankHybrid(
+        _query: string,
+        _vecResults: RecordBatch,
+        _ftsResults: RecordBatch,
+      ): Promise<RecordBatch> {
+        // no reranker logic, just return some static data
+        const table = makeArrowTable(expectedResult);
+        return table.batches[0];
+      }
+    }
+
+    let result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(new MyCustomReranker())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+    result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
+    expect(result).toEqual([
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ]);
+  });
+
+  it("will query with RRFReranker", async function () {
+    // smoke test to see if the Rust wrapping Typescript is wired up correctly
+    const result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(await RRFReranker.create())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+    expect(result).toHaveLength(2);
+  });
+});

View File

@@ -27,7 +27,9 @@ import {
   List,
   Null,
   RecordBatch,
+  RecordBatchFileReader,
   RecordBatchFileWriter,
+  RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
   }
 }
 
+/**
+ * Read a single record batch from a buffer.
+ *
+ * Returns null if the buffer does not contain a record batch
+ */
+export async function fromBufferToRecordBatch(
+  data: Buffer,
+): Promise<RecordBatch | null> {
+  const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
+    .value;
+  const recordBatch = iter?.next().value;
+  return recordBatch || null;
+}
+
+/**
+ * Create a buffer containing a single record batch
+ */
+export async function fromRecordBatchToBuffer(
+  batch: RecordBatch,
+): Promise<Buffer> {
+  const writer = new RecordBatchFileWriter().writeAll([batch]);
+  return Buffer.from(await writer.toUint8Array());
+}
+
 /**
  * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
  *

View File

@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
 export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";
 
 export * as embedding from "./embedding";
+export * as rerankers from "./rerankers";
 
 /**
  * Connect to a LanceDB instance at the given URI.

View File

@@ -16,6 +16,8 @@ import {
   Table as ArrowTable,
   type IntoVector,
   RecordBatch,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   tableFromIPC,
 } from "./arrow";
 import { type IvfPqOptions } from "./indices";
@@ -25,6 +27,7 @@ import {
   Table as NativeTable,
   VectorQuery as NativeVectorQuery,
 } from "./native";
+import { Reranker } from "./rerankers";
 
 export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
   private promisedInner?: Promise<NativeBatchIterator>;
   private inner?: NativeBatchIterator;
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
       return this;
     }
   }
+
+  rerank(reranker: Reranker): VectorQuery {
+    super.doCall((inner) =>
+      inner.rerank({
+        rerankHybrid: async (_, args) => {
+          const vecResults = await fromBufferToRecordBatch(args.vecResults);
+          const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
+          const result = await reranker.rerankHybrid(
+            args.query,
+            vecResults as RecordBatch,
+            ftsResults as RecordBatch,
+          );
+          const buffer = fromRecordBatchToBuffer(result);
+          return buffer;
+        },
+      }),
+    );
+    return this;
+  }
 }
 
 /** A builder for LanceDB queries. */

View File

@@ -0,0 +1,17 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+
+export * from "./rrf";
+
+// Interface for a reranker. A reranker is used to rerank the results from a
+// vector and FTS search. This is useful for combining the results from both
+// search methods.
+export interface Reranker {
+  rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch>;
+}

View File

@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
+import { RrfReranker as NativeRRFReranker } from "../native";
+
+/**
+ * Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
+ *
+ * Internally this uses the Rust implementation
+ */
+export class RRFReranker {
+  private inner: NativeRRFReranker;
+
+  constructor(inner: NativeRRFReranker) {
+    this.inner = inner;
+  }
+
+  public static async create(k: number = 60) {
+    return new RRFReranker(
+      await NativeRRFReranker.tryNew(new Float32Array([k])),
+    );
+  }
+
+  async rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch> {
+    const buffer = await this.inner.rerankHybrid(
+      query,
+      await fromRecordBatchToBuffer(vecResults),
+      await fromRecordBatchToBuffer(ftsResults),
+    );
+    const recordBatch = await fromBufferToRecordBatch(buffer);
+    return recordBatch as RecordBatch;
+  }
+}

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": [
     "win32"
   ],

View File

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.14.1-beta.6",
+      "version": "0.14.1",
       "cpu": [
         "x64",
         "arm64"

View File

@@ -11,7 +11,7 @@
     "ann"
   ],
   "private": false,
-  "version": "0.14.1-beta.6",
+  "version": "0.14.1",
   "main": "dist/index.js",
   "exports": {
     ".": "./dist/index.js",

View File

@@ -24,6 +24,7 @@ mod iterator;
 pub mod merge;
 mod query;
 pub mod remote;
+mod rerankers;
 mod table;
 mod util;

View File

@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::sync::Arc;
+
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::ExecutableQuery;
 use lancedb::query::Query as LanceDbQuery;
@@ -25,6 +27,8 @@ use napi_derive::napi;
 use crate::error::convert_error;
 use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
+use crate::rerankers::Reranker;
+use crate::rerankers::RerankerCallbacks;
 use crate::util::parse_distance_type;
 
 #[napi]
@@ -218,6 +222,14 @@ impl VectorQuery {
         self.inner = self.inner.clone().with_row_id();
     }
 
+    #[napi]
+    pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
+        self.inner = self
+            .inner
+            .clone()
+            .rerank(Arc::new(Reranker::new(callbacks)));
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,

View File: nodejs/src/rerankers.rs (new file)

@@ -0,0 +1,147 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow_array::RecordBatch;
+use async_trait::async_trait;
+use napi::{
+    bindgen_prelude::*,
+    threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
+};
+use napi_derive::napi;
+
+use lancedb::ipc::batches_to_ipc_file;
+use lancedb::rerankers::Reranker as LanceDBReranker;
+use lancedb::{error::Error, ipc::ipc_file_to_batches};
+
+use crate::error::NapiErrorExt;
+
+/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
+/// This contains references to the callbacks that can be used to invoke the
+/// reranking methods on the NodeJS implementation and handles serializing the
+/// record batches to Arrow IPC buffers.
+#[napi]
+pub struct Reranker {
+    /// callback to the Javascript which will call the rerankHybrid method of
+    /// some Reranker implementation
+    rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
+}
+
+#[napi]
+impl Reranker {
+    #[napi]
+    pub fn new(callbacks: RerankerCallbacks) -> Self {
+        let rerank_hybrid = callbacks
+            .rerank_hybrid
+            .create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
+            .unwrap();
+        Self { rerank_hybrid }
+    }
+}
+
+#[async_trait]
+impl lancedb::rerankers::Reranker for Reranker {
+    async fn rerank_hybrid(
+        &self,
+        query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> lancedb::error::Result<RecordBatch> {
+        let callback_args = RerankHybridCallbackArgs {
+            query: query.to_string(),
+            vec_results: batches_to_ipc_file(&[vector_results])?,
+            fts_results: batches_to_ipc_file(&[fts_results])?,
+        };
+
+        let promised_buffer: Promise<Buffer> = self
+            .rerank_hybrid
+            .call_async(Ok(callback_args))
+            .await
+            .map_err(|e| Error::Runtime {
+                message: format!("napi error status={}, reason={}", e.status, e.reason),
+            })?;
+        let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
+            message: format!("napi error status={}, reason={}", e.status, e.reason),
+        })?;
+
+        let mut reader = ipc_file_to_batches(buffer.to_vec())?;
+        let result = reader.next().ok_or(Error::Runtime {
+            message: "reranker result deserialization failed".to_string(),
+        })??;
+
+        return Ok(result);
+    }
+}
+
+impl std::fmt::Debug for Reranker {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("NodeJSRerankerWrapper")
+    }
+}
+
+#[napi(object)]
+pub struct RerankerCallbacks {
+    pub rerank_hybrid: JsFunction,
+}
+
+#[napi(object)]
+pub struct RerankHybridCallbackArgs {
+    pub query: String,
+    pub vec_results: Vec<u8>,
+    pub fts_results: Vec<u8>,
+}
+
+fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
+    let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
+    reader
+        .next()
+        .ok_or(Error::InvalidInput {
+            message: "expected buffer containing record batch".to_string(),
+        })
+        .default_error()?
+        .map_err(Error::from)
+        .default_error()
+}
+
+/// Wrapper around rust RRFReranker
+#[napi]
+pub struct RRFReranker {
+    inner: lancedb::rerankers::rrf::RRFReranker,
+}
+
+#[napi]
+impl RRFReranker {
+    #[napi]
+    pub async fn try_new(k: &[f32]) -> Result<Self> {
+        let k = k
+            .first()
+            .copied()
+            .ok_or(Error::InvalidInput {
+                message: "must supply RRF Reranker constructor arg 'k'".to_string(),
+            })
+            .default_error()?;
+
+        Ok(Self {
+            inner: lancedb::rerankers::rrf::RRFReranker::new(k),
+        })
+    }
+
+    #[napi]
+    pub async fn rerank_hybrid(
+        &self,
+        query: String,
+        vec_results: Buffer,
+        fts_results: Buffer,
+    ) -> Result<Buffer> {
+        let vec_results = buffer_to_record_batch(vec_results)?;
+        let fts_results = buffer_to_record_batch(fts_results)?;
+
+        let result = self
+            .inner
+            .rerank_hybrid(&query, vec_results, fts_results)
+            .await
+            .unwrap();
+
+        let result_buff = batches_to_ipc_file(&[result]).default_error()?;
+
+        Ok(Buffer::from(result_buff.as_ref()))
+    }
+}

View File

@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.17.1"
+current_version = "0.17.2-beta.2"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.17.1"
+version = "0.17.2-beta.2"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true

View File

@@ -1,9 +1,10 @@
 [project]
 name = "lancedb"
 # version in Cargo.toml
+dynamic = ["version"]
 dependencies = [
     "deprecation",
-    "pylance==0.21.0b5",
+    "pylance==0.21.1b1",
     "tqdm>=4.27.0",
     "pydantic>=1.10",
     "packaging",
@@ -52,8 +53,9 @@ tests = [
     "pytz",
     "polars>=0.19, <=1.3.0",
     "tantivy",
+    "pyarrow-stubs"
 ]
-dev = ["ruff", "pre-commit"]
+dev = ["ruff", "pre-commit", "pyright"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -93,3 +95,7 @@ markers = [
     "asyncio",
     "s3_test",
 ]
+
+[tool.pyright]
+include = ["python/lancedb/table.py"]
+pythonVersion = "3.12"

View File

@@ -1,7 +1,9 @@
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Any, Union, Literal
import pyarrow as pa import pyarrow as pa
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
class Connection(object): class Connection(object):
uri: str uri: str
async def table_names( async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
class Table: class Table:
def name(self) -> str: ... def name(self) -> str: ...
def __repr__(self) -> str: ... def __repr__(self) -> str: ...
def is_open(self) -> bool: ...
def close(self) -> None: ...
async def schema(self) -> pa.Schema: ... async def schema(self) -> pa.Schema: ...
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ... async def add(
self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
) -> None: ...
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ... async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
async def count_rows(self, filter: Optional[str]) -> int: ... async def count_rows(self, filter: Optional[str]) -> int: ...
async def create_index(self, column: str, config, replace: Optional[bool]): ... async def create_index(
self,
column: str,
index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
replace: Optional[bool],
): ...
async def list_versions(self) -> List[Dict[str, Any]]: ...
async def version(self) -> int: ... async def version(self) -> int: ...
async def checkout(self, version): ... async def checkout(self, version: int): ...
async def checkout_latest(self): ... async def checkout_latest(self): ...
async def restore(self): ... async def restore(self): ...
async def list_indices(self) -> List[IndexConfig]: ... async def list_indices(self) -> list[IndexConfig]: ...
async def delete(self, filter: str): ...
async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
async def optimize(
self,
*,
cleanup_since_ms: Optional[int] = None,
delete_unverified: Optional[bool] = None,
) -> OptimizeStats: ...
def query(self) -> Query: ... def query(self) -> Query: ...
def vector_search(self) -> VectorQuery: ... def vector_search(self) -> VectorQuery: ...

View File

@@ -603,7 +603,7 @@ class AsyncConnection(object):
         fill_value: Optional[float] = None,
         storage_options: Optional[Dict[str, str]] = None,
         *,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         data_storage_version: Optional[str] = None,
         use_legacy_format: Optional[bool] = None,
         enable_v2_manifest_paths: Optional[bool] = None,

View File

@@ -1,20 +1,10 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
 """Full text search index using tantivy-py"""
 
 import os
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
 import pyarrow as pa
@@ -31,7 +21,7 @@ from .table import LanceTable
 def create_index(
     index_path: str,
     text_fields: List[str],
-    ordering_fields: List[str] = None,
+    ordering_fields: Optional[List[str]] = None,
     tokenizer_name: str = "default",
 ) -> tantivy.Index:
     """
@@ -75,8 +65,8 @@ def populate_index(
     index: tantivy.Index,
     table: LanceTable,
     fields: List[str],
-    writer_heap_size: int = 1024 * 1024 * 1024,
-    ordering_fields: List[str] = None,
+    writer_heap_size: Optional[int] = None,
+    ordering_fields: Optional[List[str]] = None,
 ) -> int:
     """
     Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@ def populate_index(
     """
     if ordering_fields is None:
         ordering_fields = []
+    writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
     # first check the fields exist and are string or large string type
     nested = []
View File

@@ -568,4 +568,14 @@ class IvfPq:
     sample_rate: int = 256
 
-__all__ = ["BTree", "IvfFlat", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]
+__all__ = [
+    "BTree",
+    "IvfPq",
+    "IvfFlat",
+    "HnswPq",
+    "HnswSq",
+    "IndexConfig",
+    "FTS",
+    "Bitmap",
+    "LabelList",
+]

View File

@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
     # e.g. `{"nprobes": "10", "refine_factor": "10"}`
     nprobes: int = 10
 
+    lower_bound: Optional[float] = None
+    upper_bound: Optional[float] = None
+
     # Refine factor.
     refine_factor: Optional[int] = None
 
@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._query = query
         self._metric = "L2"
         self._nprobes = 20
+        self._lower_bound = None
+        self._upper_bound = None
         self._refine_factor = None
         self._vector_column = vector_column
         self._prefilter = False
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self
 
+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceVectorQueryBuilder:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceVectorQueryBuilder:
         """Set the number of candidates to consider during search.
 
@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             metric=self._metric,
             columns=self._columns,
             nprobes=self._nprobes,
+            lower_bound=self._lower_bound,
+            upper_bound=self._upper_bound,
             refine_factor=self._refine_factor,
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self
 
+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceHybridQueryBuilder:
+        """
+        Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceHybridQueryBuilder
+            The LanceHybridQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceHybridQueryBuilder:
         """
         Set the number of candidates to consider during search.
@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
         self._inner.nprobes(nprobes)
         return self
 
+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> AsyncVectorQuery:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        AsyncVectorQuery
+            The AsyncVectorQuery object.
+        """
+        self._inner.distance_range(lower_bound, upper_bound)
+        return self
+
     def ef(self, ef: int) -> AsyncVectorQuery:
         """
         Set the number of candidates to consider during search

View File

@@ -1,15 +1,5 @@
# Copyright 2023 LanceDB Developers # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright The LanceDB Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
import logging
@@ -19,7 +9,7 @@ import warnings
from lancedb._lancedb import IndexConfig
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa
@@ -91,7 +81,7 @@ class RemoteTable(Table):
"""to_pandas() is not yet supported on LanceDB cloud."""
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
def checkout(self, version):
def checkout(self, version: int):
return LOOP.run(self._table.checkout(version))
def checkout_latest(self):
@@ -235,10 +225,12 @@ class RemoteTable(Table):
config = HnswPq(distance_type=metric)
elif index_type == "IVF_HNSW_SQ":
config = HnswSq(distance_type=metric)
elif index_type == "IVF_FLAT":
config = IvfFlat(distance_type=metric)
else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
LOOP.run(self._table.create_index(vector_column_name, config=config))
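The Rust remote table gains the same IVF_FLAT branch later in this diff. A minimal sketch of creating the index from the Rust SDK, assuming the builder paths exercised by the remote tests below and the crate's root `Result` re-export; the cosine metric is illustrative:

```rust
use lancedb::index::{vector::IvfFlatIndexBuilder, Index};
use lancedb::{DistanceType, Table};

// Sketch: build an IVF_FLAT vector index on the "vector" column.
// IVF_FLAT keeps raw vectors (no PQ compression), which is why it can
// also serve hamming distance over binary vectors, as the remote test
// case added below demonstrates.
async fn create_ivf_flat(table: &Table) -> lancedb::Result<()> {
    table
        .create_index(
            &["vector"],
            Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Cosine)),
        )
        .execute()
        .await
}
```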

View File

@@ -61,11 +61,12 @@ from .index import lang_mapping
if TYPE_CHECKING:
import PIL
from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats
from .db import LanceDBConnection
from .index import IndexConfig
from lance.dataset import CleanupStats, ReaderLike
import pandas
import PIL
pd = safe_import_pandas()
pl = safe_import_polars()
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
)
if not embedding_functions:
return schema
columns = set(columns)
return pa.schema([field for field in schema if field.name in columns])
@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
return pa.Table.from_batches(data, schema=schema)
else:
return pa.Table.from_pylist(data, schema=schema)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):  # type: ignore
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
# Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
metadata: Optional[dict] = None,  # embedding metadata
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
) -> Tuple[pa.Table, pa.Schema]:
data = _coerce_to_table(data, schema)
if metadata:
@@ -178,13 +178,17 @@ def _sanitize_data(
def sanitize_create_table(
data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
data,
schema: Union[pa.Schema, LanceModel],
metadata=None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
schema: pa.Schema = schema.to_arrow_schema()
if data is not None:
if metadata is None and schema is not None:
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
return data
def _generator_to_data_and_schema(
data: Iterable,
) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
def _with_first_generator(first, data):
yield first
yield from data
first = next(data, None)
schema = None
if isinstance(first, pa.RecordBatch):
schema = first.schema
data = _with_first_generator(first, data)
elif isinstance(first, pa.Table):
schema = first.schema
data = _with_first_generator(first.to_batches(), data)
return data, schema
def _to_record_batch_generator(
data: Iterable,
schema,
metadata,
on_bad_vectors,
fill_value,
):
for batch in data:
# always convert to table because we need to sanitize the data
# and do things like add the vector column etc
if isinstance(batch, pa.RecordBatch):
batch = pa.Table.from_batches([batch])
batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for b in batch.to_batches():
yield b
def _table_path(base: str, table_name: str) -> str:
"""
Get a table path that can be used in PyArrow FS.
@@ -404,7 +373,7 @@ class Table(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def to_pandas(self) -> "pd.DataFrame": def to_pandas(self) -> "pandas.DataFrame":
"""Return the table as a pandas DataFrame. """Return the table as a pandas DataFrame.
Returns Returns
@@ -537,8 +506,8 @@ class Table(ABC):
def create_fts_index(
self,
field_names: Union[str, List[str]],
ordering_field_names: Union[str, List[str]] = None,
*,
ordering_field_names: Optional[Union[str, List[str]]] = None,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
use_tantivy: bool = True,
@@ -790,8 +759,7 @@ class Table(ABC):
@abstractmethod
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
) -> pa.RecordBatchReader: ...
pass
@abstractmethod
def _do_merge(
@@ -800,8 +768,7 @@ class Table(ABC):
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
): ...
pass
@abstractmethod
def delete(self, where: str):
@@ -1092,7 +1059,7 @@ class Table(ABC):
""" """
@abstractmethod @abstractmethod
def checkout(self): def checkout(self, version: int):
""" """
Checks out a specific version of the Table Checks out a specific version of the Table
@@ -1121,7 +1088,7 @@ class Table(ABC):
""" """
@abstractmethod @abstractmethod
def list_versions(self): def list_versions(self) -> List[Dict[str, Any]]:
"""List all versions of the table""" """List all versions of the table"""
@cached_property @cached_property
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
A PyArrow schema object."""
return LOOP.run(self._table.schema())
def list_versions(self):
def list_versions(self) -> List[Dict[str, Any]]:
"""List all versions of the table"""
return LOOP.run(self._table.list_versions())
@@ -1297,7 +1264,7 @@ class LanceTable(Table):
""" """
LOOP.run(self._table.checkout_latest()) LOOP.run(self._table.checkout_latest())
def restore(self, version: int = None): def restore(self, version: Optional[int] = None):
"""Restore a version of the table. This is an in-place operation. """Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@ class LanceTable(Table):
def count_rows(self, filter: Optional[str] = None) -> int:
return LOOP.run(self._table.count_rows(filter))
def __len__(self):
def __len__(self) -> int:
return self.count_rows()
def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@ class LanceTable(Table):
def create_fts_index(
self,
field_names: Union[str, List[str]],
ordering_field_names: Union[str, List[str]] = None,
*,
ordering_field_names: Optional[Union[str, List[str]]] = None,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@ class LanceTable(Table):
writer_heap_size=writer_heap_size,
)
@staticmethod
def infer_tokenizer_configs(tokenizer_name: str) -> dict:
if tokenizer_name == "default":
return {
@@ -1759,7 +1727,7 @@ class LanceTable(Table):
)
@overload
def search(
def search(  # type: ignore
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
@@ -1895,11 +1863,11 @@ class LanceTable(Table):
name: str,
data: Optional[DATA] = None,
schema: Optional[pa.Schema] = None,
mode: Literal["create", "overwrite", "append"] = "create",
mode: Literal["create", "overwrite"] = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionConfig] = None,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*,
storage_options: Optional[Dict[str, str]] = None,
data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@ class LanceTable(Table):
older_than, delete_unverified=delete_unverified
)
def compact_files(self, *args, **kwargs):
def compact_files(self, *args, **kwargs) -> CompactionStats:
"""
Run the compaction process on the table.
@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid:
except pa.lib.ArrowInvalid:  # type: ignore
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
on_bad_vectors = "error" on_bad_vectors = "error"
if fill_value is None: if fill_value is None:
fill_value = 0.0 fill_value = 0.0
data, _ = _sanitize_data( table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
data, data,
schema, schema,
metadata=schema.metadata, metadata=schema.metadata,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
) )
if isinstance(data, pa.Table): tbl, schema = table_and_schema
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches()) if isinstance(tbl, pa.Table):
await self._inner.add(data, mode) data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
await self._inner.add(data, mode or "append")
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
""" """
@@ -2817,6 +2786,7 @@ class AsyncTable:
async_query.nearest_to(query.vector)
.distance_type(query.metric)
.nprobes(query.nprobes)
.distance_range(query.lower_bound, query.upper_bound)
)
if query.refine_factor:
async_query = async_query.refine_factor(query.refine_factor)
@@ -2977,7 +2947,7 @@ class AsyncTable:
return await self._inner.update(updates_sql, where)
async def add_columns(self, transforms: Dict[str, str]):
async def add_columns(self, transforms: dict[str, str]):
"""
Add new columns with defined values.
@@ -2990,7 +2960,7 @@ class AsyncTable:
""" """
await self._inner.add_columns(list(transforms.items())) await self._inner.add_columns(list(transforms.items()))
async def alter_columns(self, *alterations: Iterable[Dict[str, str]]): async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
""" """
Alter column names and nullability. Alter column names and nullability.
@@ -3049,7 +3019,7 @@ class AsyncTable:
return versions
async def checkout(self, version):
async def checkout(self, version: int):
"""
Checks out a specific version of the Table
@@ -3148,9 +3118,12 @@ class AsyncTable:
you have added or modified 100,000 or more records or run more than 20 data
modification operations.
"""
cleanup_since_ms: Optional[int] = None
if cleanup_older_than is not None:
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
return await self._inner.optimize(cleanup_older_than, delete_unverified)
return await self._inner.optimize(
cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
)
async def list_indices(self) -> Iterable[IndexConfig]:
"""

View File

@@ -167,8 +167,24 @@ def test_search_index(tmp_path, table):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_search_fts(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
results = table.search("puppy").limit(5).to_list()
results = table.search("puppy").select(["id", "text"]).limit(5).to_list()
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
@pytest.mark.asyncio
async def test_fts_select_async(async_table):
tbl = await async_table
await tbl.create_index("text", config=FTS())
results = (
await tbl.query()
.nearest_to_text("puppy")
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
def test_search_fts_phrase_query(table):

View File

@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
assert rs["_rowid"].to_pylist() == [0, 1] assert rs["_rowid"].to_pylist() == [0, 1]
def test_distance_range(table: lancedb.table.Table):
q = [0, 0]
rs = table.search(q).to_arrow()
dists = rs["_distance"].to_pylist()
min_dist = dists[0]
max_dist = dists[-1]
res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
assert len(res) == 0
res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
assert len(res) == 1
assert res["_distance"].to_pylist() == [max_dist]
res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
assert len(res) == 1
assert res["_distance"].to_pylist() == [min_dist]
res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
assert len(res) == 2
assert res["_distance"].to_pylist() == [min_dist, max_dist]
@pytest.mark.asyncio
async def test_distance_range_async(table_async: AsyncTable):
q = [0, 0]
rs = await table_async.query().nearest_to(q).to_arrow()
dists = rs["_distance"].to_pylist()
min_dist = dists[0]
max_dist = dists[-1]
res = (
await table_async.query()
.nearest_to(q)
.distance_range(upper_bound=min_dist)
.to_arrow()
)
assert len(res) == 0
res = (
await table_async.query()
.nearest_to(q)
.distance_range(lower_bound=max_dist)
.to_arrow()
)
assert len(res) == 1
assert res["_distance"].to_pylist() == [max_dist]
res = (
await table_async.query()
.nearest_to(q)
.distance_range(upper_bound=max_dist)
.to_arrow()
)
assert len(res) == 1
assert res["_distance"].to_pylist() == [min_dist]
res = (
await table_async.query()
.nearest_to(q)
.distance_range(lower_bound=min_dist)
.to_arrow()
)
assert len(res) == 2
assert res["_distance"].to_pylist() == [min_dist, max_dist]
def test_vector_query_with_no_limit(table):
with pytest.raises(ValueError):
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(

View File

@@ -306,6 +306,8 @@ def test_query_sync_minimal():
"k": 10, "k": 10,
"prefilter": False, "prefilter": False,
"refine_factor": None, "refine_factor": None,
"lower_bound": None,
"upper_bound": None,
"ef": None, "ef": None,
"vector": [1.0, 2.0, 3.0], "vector": [1.0, 2.0, 3.0],
"nprobes": 20, "nprobes": 20,
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
"refine_factor": 10, "refine_factor": 10,
"vector": [1.0, 2.0, 3.0], "vector": [1.0, 2.0, 3.0],
"nprobes": 5, "nprobes": 5,
"lower_bound": None,
"upper_bound": None,
"ef": None, "ef": None,
"filter": "id > 0", "filter": "id > 0",
"columns": ["id", "name"], "columns": ["id", "name"],
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
"refine_factor": None, "refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20, "nprobes": 20,
"lower_bound": None,
"upper_bound": None,
"ef": None, "ef": None,
"with_row_id": True, "with_row_id": True,
"version": None, "version": None,

View File

@@ -152,6 +152,10 @@ impl FTSQuery {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(Select::columns(&columns));
}
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
@@ -280,6 +284,11 @@ impl VectorQuery {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
}
pub fn ef(&mut self, ef: u32) {
self.inner = self.inner.clone().ef(ef as usize);
}
@@ -341,6 +350,11 @@ impl HybridQuery {
self.inner_fts.select(columns);
}
pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner_vec.select_columns(columns.clone());
self.inner_fts.select_columns(columns);
}
pub fn limit(&mut self, limit: u32) {
self.inner_vec.limit(limit);
self.inner_fts.limit(limit);

View File

@@ -97,10 +97,12 @@ impl Table {
self.name.clone()
}
/// Returns True if the table is open, False if it is closed.
pub fn is_open(&self) -> bool {
self.inner.is_some()
}
/// Closes the table, releasing any resources associated with it.
pub fn close(&mut self) {
self.inner.take();
}
@@ -301,6 +303,7 @@ impl Table {
Query::new(self.inner_ref().unwrap().query())
}
/// Optimize the on-disk data by compacting and pruning old data, for better performance.
#[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None))]
pub fn optimize(
self_: PyRef<'_, Self>,

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.14.1-beta.6"
version = "0.14.1"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.14.1-beta.6"
version = "0.14.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
pub mod query;
#[cfg(feature = "remote")]
pub mod remote;
pub mod rerankers;
pub mod table;
pub mod utils;

View File

@@ -15,19 +15,31 @@
use std::future::Future;
use std::sync::Arc;
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use half::f16;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::{
arrow::RecordBatchExt,
dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
};
use lance_datafusion::exec::execute_plan;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::TableInternal;
use crate::DistanceType;
mod hybrid;
pub(crate) const DEFAULT_TOP_K: usize = 10;
/// Which columns should be retrieved from the database
@@ -435,6 +447,16 @@ pub trait QueryBase {
/// Return the `_rowid` meta column from the Table.
fn with_row_id(self) -> Self;
/// Rerank the results using the specified reranker.
///
/// This is currently only supported for Hybrid Search.
fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;
/// The method to normalize the scores. Can be "rank" or "score". If "rank",
/// the scores are converted to ranks and then normalized. If "score", the
/// scores are normalized directly.
fn norm(self, norm: NormalizeMethod) -> Self;
}
pub trait HasQuery {
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
self.mut_query().with_row_id = true;
self
}
fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
self.mut_query().reranker = Some(reranker);
self
}
fn norm(mut self, norm: NormalizeMethod) -> Self {
self.mut_query().norm = Some(norm);
self
}
}
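Taken together with the hybrid execution path added below, these hooks compose on a query as in this hedged sketch (the query text, vector, and table handle are illustrative; `RRFReranker` is also the fallback when no reranker is set):

```rust
use std::sync::Arc;

use lance_index::scalar::FullTextSearchQuery;
use lancedb::arrow::SendableRecordBatchStream;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::rerankers::{rrf::RRFReranker, NormalizeMethod};
use lancedb::Table;

// Sketch: a hybrid query is a full-text leg plus a vector leg on the
// same query; `rerank` is shown explicitly even though RRF is the default.
async fn hybrid_with_rrf(table: &Table) -> lancedb::Result<SendableRecordBatchStream> {
    table
        .query()
        .full_text_search(FullTextSearchQuery::new("puppy".to_string()))
        .nearest_to(&[0.1, 0.2])?
        .norm(NormalizeMethod::Rank)
        .rerank(Arc::new(RRFReranker::default()))
        .limit(10)
        .execute()
        .await
}
```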
/// Options for controlling the execution of a query
@@ -600,6 +632,13 @@ pub struct Query {
/// If set to false, the filter will be applied after the vector search.
pub(crate) prefilter: bool,
/// Implementation of reranker that can be used to reorder or combine query
/// results, especially if using hybrid search
pub(crate) reranker: Option<Arc<dyn Reranker>>,
/// Configure how query results are normalized when doing hybrid search
pub(crate) norm: Option<NormalizeMethod>,
}
impl Query {
@@ -614,6 +653,8 @@ impl Query {
fast_search: false,
with_row_id: false,
prefilter: true,
reranker: None,
norm: None,
}
}
@@ -714,6 +755,10 @@ pub struct VectorQuery {
// IVF PQ - ANN search.
pub(crate) query_vector: Vec<Arc<dyn Array>>,
pub(crate) nprobes: usize,
// The lower bound (inclusive) of the distance to search for.
pub(crate) lower_bound: Option<f32>,
// The upper bound (exclusive) of the distance to search for.
pub(crate) upper_bound: Option<f32>,
// The number of candidates to return during the refine step for HNSW,
// defaults to 1.5 * limit.
pub(crate) ef: Option<usize>,
@@ -730,6 +775,8 @@ impl VectorQuery {
column: None,
query_vector: Vec::new(),
nprobes: 20,
lower_bound: None,
upper_bound: None,
ef: None,
refine_factor: None,
distance_type: None,
@@ -790,6 +837,14 @@ impl VectorQuery {
self
}
/// Set the distance range for vector search,
/// only rows with distances in the range [lower_bound, upper_bound) will be returned
pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
self.lower_bound = lower_bound;
self.upper_bound = upper_bound;
self
}
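A minimal usage sketch of the new range, mirroring the test at the bottom of this file; both bounds are optional, and the range is half-open, [lower_bound, upper_bound):

```rust
use lancedb::arrow::SendableRecordBatchStream;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::Table;

// Sketch: keep only neighbors whose distance d satisfies 0.0 <= d < 1.0;
// pass None to leave either end of the range unbounded.
async fn close_matches(table: &Table) -> lancedb::Result<SendableRecordBatchStream> {
    table
        .vector_search(&[0.1, 0.2, 0.3, 0.4])?
        .distance_range(Some(0.0), Some(1.0))
        .limit(10)
        .execute()
        .await
}
```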
/// Set the number of candidates to return during the refine step for HNSW
///
/// This argument is only used when the vector column has an HNSW index.
@@ -862,6 +917,65 @@ impl VectorQuery {
self.use_index = false;
self
}
pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
// clone query and specify we want to include row IDs, which can be needed for reranking
let fts_query = self.base.clone().with_row_id();
let mut vector_query = self.clone().with_row_id();
vector_query.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
let (fts_results, vec_results) = try_join!(
fts_results.try_collect::<Vec<_>>(),
vec_results.try_collect::<Vec<_>>()
)?;
// try to get the schema to use when combining batches.
// if either result set is empty, its schema is derived from the other
let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);
// concatenate all the batches together
let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
}
vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
let reranker = self
.base
.reranker
.clone()
.unwrap_or(Arc::new(RRFReranker::default()));
let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
message: "there should be an FTS search".to_string(),
})?;
let mut results = reranker
.rerank_hybrid(&fts_query.query, vec_results, fts_results)
.await?;
check_reranker_result(&results)?;
let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
if results.num_rows() > limit {
results = results.slice(0, limit);
}
if !self.base.with_row_id {
results = results.drop_column(ROW_ID)?;
}
Ok(SendableRecordBatchStream::from(
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
}
}
impl ExecutableQuery for VectorQuery {
@@ -873,6 +987,11 @@ impl ExecutableQuery for VectorQuery {
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
if self.base.full_text_search.is_some() {
let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
return Ok(hybrid_result);
}
Ok(SendableRecordBatchStream::from(
DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
@@ -894,20 +1013,20 @@ impl HasQuery for VectorQuery {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};
use super::*;
use arrow::{compute::concat_batches, datatypes::Int32Type};
use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
use arrow_array::{
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
RecordBatchReader,
cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use futures::{StreamExt, TryStreamExt};
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
use tempfile::tempdir;
use crate::{connect, Table};
use crate::{connect, connection::CreateTableMode, Table};
#[tokio::test]
async fn test_setters_getters() {
@@ -1245,6 +1364,30 @@ mod tests {
}
}
#[tokio::test]
async fn test_distance_range() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let results = table
.vector_search(&[0.1, 0.2, 0.3, 0.4])
.unwrap()
.distance_range(Some(0.0), Some(1.0))
.limit(10)
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
for batch in results {
let distances = batch["_distance"].as_primitive::<Float32Type>();
assert!(distances.iter().all(|d| {
let d = d.unwrap();
(0.0..1.0).contains(&d)
}));
}
}
#[tokio::test]
async fn test_multiple_query_vectors() {
let tmp_dir = tempdir().unwrap();
@@ -1274,4 +1417,156 @@ mod tests {
assert!(query_index.values().contains(&0));
assert!(query_index.values().contains(&1));
}
#[tokio::test]
async fn test_hybrid_search() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
let vectors = vec![
Some(vec![Some(0.0), Some(0.0)]),
Some(vec![Some(-2.0), Some(-2.0)]),
Some(vec![Some(50.0), Some(50.0)]),
Some(vec![Some(-30.0), Some(-30.0)]),
];
let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);
let record_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
let record_batch_iter =
RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let fts_query = FullTextSearchQuery::new("b".to_string());
let results = table
.query()
.full_text_search(fts_query)
.limit(2)
.nearest_to(&[-10.0, -10.0])
.unwrap()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &results[0];
let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
assert!(texts.contains("cat")); // should be close by vector search
assert!(texts.contains("b")); // should be close by fts search
// ensure that this works correctly if there are no matching FTS results
let fts_query = FullTextSearchQuery::new("z".to_string());
table
.query()
.full_text_search(fts_query)
.limit(2)
.nearest_to(&[-10.0, -10.0])
.unwrap()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
}
#[tokio::test]
async fn test_hybrid_search_empty_table() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
// ensure hybrid search is also supported on a fully empty table
let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
let record_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(Vec::<&str>::new())),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
),
],
)
.unwrap();
let record_batch_iter =
RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
let table = conn
.create_table("my_table", record_batch_iter)
.mode(CreateTableMode::Overwrite)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let fts_query = FullTextSearchQuery::new("b".to_string());
let results = table
.query()
.full_text_search(fts_query)
.limit(2)
.nearest_to(&[-10.0, -10.0])
.unwrap()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &results[0];
assert_eq!(0, batch.num_rows());
assert_eq!(2, batch.num_columns());
}
}

View File

@@ -0,0 +1,346 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow::compute::{
kernels::numeric::{div, sub},
max, min,
};
use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use lance::dataset::ROW_ID;
use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
use std::sync::Arc;
use crate::error::{Error, Result};
/// Converts the result's score column to a rank.
///
/// Expects the `column` argument to be type Float32 and will panic if it's not
pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in rank. found columns {:?}",
column,
results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>(),
),
})?;
if results.num_rows() == 0 {
return Ok(results);
}
let scores: Float32Array = downcast_array(scores);
let ranks = Float32Array::from_iter_values(
arrow::compute::kernels::rank::rank(
&scores,
Some(SortOptions {
descending: !ascending.unwrap_or(true),
..Default::default()
}),
)?
.iter()
.map(|i| *i as f32),
);
let schema = results.schema();
let (column_idx, _) = schema.column_with_name(column).unwrap();
let mut columns = results.columns().to_vec();
columns[column_idx] = Arc::new(ranks);
let results = RecordBatch::try_new(results.schema(), columns)?;
Ok(results)
}
/// Get the query schemas needed when combining the search results.
///
/// If either of the record batches are empty, then we create a schema from the
/// other record batch, and replace the score/distance column. If both record
/// batches are empty, create empty schemas.
pub fn query_schemas(
fts_results: &[RecordBatch],
vec_results: &[RecordBatch],
) -> (Arc<Schema>, Arc<Schema>) {
let (fts_schema, vec_schema) = match (
fts_results.first().map(|r| r.schema()),
vec_results.first().map(|r| r.schema()),
) {
(Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
(None, Some(vec_schema)) => {
let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
(Arc::new(fts_schema), vec_schema)
}
(Some(fts_schema), None) => {
let vec_schema = with_field_name_replaced(&fts_schema, DIST_COL, SCORE_COL);
(fts_schema, Arc::new(vec_schema))
}
(None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
};
(fts_schema, vec_schema)
}
pub fn empty_fts_schema() -> Schema {
Schema::new(vec![
Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
])
}
pub fn empty_vec_schema() -> Schema {
Schema::new(vec![
Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
])
}
pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
if field.name() == target {
Some(i)
} else {
None
}
});
let mut fields = schema.fields().to_vec();
if let Some(idx) = field_idx {
let new_field = (*fields[idx]).clone().with_name(replacement);
fields[idx] = Arc::new(new_field);
}
Schema::new(fields)
}
/// Normalize the scores column to have values between 0 and 1.
///
/// Expects the `column` argument to be type Float32 and will panic if it's not
pub fn normalize_scores(
results: RecordBatch,
column: &str,
invert: Option<bool>,
) -> Result<RecordBatch> {
let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in rank. found columns {:?}",
column,
results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>(),
),
})?;
if results.num_rows() == 0 {
return Ok(results);
}
let mut scores: Float32Array = downcast_array(scores);
let max = max(&scores).unwrap_or(0.0);
let min = min(&scores).unwrap_or(0.0);
// this is equivalent to np.isclose which is used in python
let rng = if max - min < 10e-5 { max } else { max - min };
// if rng is 0, then min and max are both 0 so we just leave the scores as is
if rng != 0.0 {
let tmp = div(
&sub(&scores, &Float32Array::new_scalar(min))?,
&Float32Array::new_scalar(rng),
)?;
scores = downcast_array(&tmp);
}
if invert.unwrap_or(false) {
let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
scores = downcast_array(&tmp);
}
let schema = results.schema();
let (column_idx, _) = schema.column_with_name(column).unwrap();
let mut columns = results.columns().to_vec();
columns[column_idx] = Arc::new(scores);
let results = RecordBatch::try_new(results.schema(), columns).unwrap();
Ok(results)
}
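In symbols, `normalize_scores` above is min-max normalization with a guard against a (near-)constant score column: for scores $s_1, \dots, s_n$,

\[
s_i' = \frac{s_i - \min(s)}{r},
\qquad
r =
\begin{cases}
\max(s) & \text{if } \max(s) - \min(s) < 10^{-4} \\
\max(s) - \min(s) & \text{otherwise,}
\end{cases}
\]

with the column left untouched when $r = 0$ (all scores zero), and $s_i' \mapsto 1 - s_i'$ applied afterwards when `invert` is set. The $10^{-4}$ cutoff (the `10e-5` literal) plays the role of Python's `np.isclose` so that scores differing only by floating-point noise normalize identically.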
#[cfg(test)]
mod test {
use super::*;
use arrow_array::StringArray;
use arrow_schema::{DataType, Field, Schema};
#[test]
fn test_rank() {
let schema = Arc::new(Schema::new(vec![
Arc::new(Field::new("name", DataType::Utf8, false)),
Arc::new(Field::new("score", DataType::Float32, false)),
]));
let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]);
let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]);
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap();
let result = rank(batch.clone(), "score", Some(false)).unwrap();
assert_eq!(2, result.schema().fields().len());
assert_eq!("name", result.schema().field(0).name());
assert_eq!("score", result.schema().field(1).name());
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![4.0, 3.0, 5.0, 1.0, 2.0]
);
// check sort ascending
let result = rank(batch.clone(), "score", Some(true)).unwrap();
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![2.0, 3.0, 1.0, 5.0, 4.0]
);
// ensure default sort is ascending
let result = rank(batch.clone(), "score", None).unwrap();
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![2.0, 3.0, 1.0, 5.0, 4.0]
);
// check it can handle an empty batch
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(Vec::<&str>::new())),
Arc::new(Float32Array::from(Vec::<f32>::new())),
],
)
.unwrap();
let result = rank(batch.clone(), "score", None).unwrap();
assert_eq!(0, result.num_rows());
assert_eq!(2, result.schema().fields().len());
assert_eq!("name", result.schema().field(0).name());
assert_eq!("score", result.schema().field(1).name());
// check it returns the expected error when there's no column
let result = rank(batch.clone(), "bad_col", None);
match result {
Err(Error::InvalidInput { message }) => {
assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
}
_ => {
panic!("expected invalid input error, received {:?}", result)
}
}
}
#[test]
fn test_normalize_scores() {
let schema = Arc::new(Schema::new(vec![
Arc::new(Field::new("name", DataType::Utf8, false)),
Arc::new(Field::new("score", DataType::Float32, false)),
]));
let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.6, 0.4, 0.7, 1.0]
);
// check it can invert the normalization
let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
);
// check that the default is not inverted
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.6, 0.4, 0.7, 1.0]
);
// check that it will function correctly if all the values are the same
let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.0, 0.0, 0.0, 0.0]
);
// check it keeps floating point rounding errors for same score normalized the same
// e.g., the behaviour is consistent with python
let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![
1.0 - 0.9999999,
1.0 - 0.9999999,
1.0 - 0.9999999,
1.0 - 0.9999999,
0.0
]
);
// check that it can handle if all the scores are 0
let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.0, 0.0, 0.0, 0.0]
);
}
}

View File

@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
body["prefilter"] = query.base.prefilter.into(); body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into(); body["nprobes"] = query.nprobes.into();
body["lower_bound"] = query.lower_bound.into();
body["upper_bound"] = query.upper_bound.into();
body["ef"] = query.ef.into(); body["ef"] = query.ef.into();
body["refine_factor"] = query.refine_factor.into(); body["refine_factor"] = query.refine_factor.into();
if let Some(vector_column) = query.column.as_ref() { if let Some(vector_column) = query.column.as_ref() {
@@ -563,6 +565,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let (index_type, distance_type) = match index.index {
// TODO: Should we pass the actual index parameters? SaaS does not
// yet support them.
Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)),
Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
Index::BTree(_) => ("BTREE", None),
@@ -873,6 +876,7 @@ mod tests {
use lance_index::scalar::FullTextSearchQuery;
use reqwest::Body;
use crate::index::vector::IvfFlatIndexBuilder;
use crate::{
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
query::{ExecutableQuery, QueryBase},
@@ -1302,6 +1306,8 @@ mod tests {
"prefilter": true, "prefilter": true,
"distance_type": "l2", "distance_type": "l2",
"nprobes": 20, "nprobes": 20,
"lower_bound": Option::<f32>::None,
"upper_bound": Option::<f32>::None,
"k": 10, "k": 10,
"ef": Option::<usize>::None, "ef": Option::<usize>::None,
"refine_factor": null, "refine_factor": null,
@@ -1351,6 +1357,8 @@ mod tests {
"bypass_vector_index": true, "bypass_vector_index": true,
"columns": ["a", "b"], "columns": ["a", "b"],
"nprobes": 12, "nprobes": 12,
"lower_bound": Option::<f32>::None,
"upper_bound": Option::<f32>::None,
"ef": Option::<usize>::None, "ef": Option::<usize>::None,
"refine_factor": 2, "refine_factor": 2,
"version": null, "version": null,
@@ -1489,6 +1497,11 @@ mod tests {
#[tokio::test]
async fn test_create_index() {
let cases = [
(
"IVF_FLAT",
Some("hamming"),
Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
),
("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())), ("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
( (
"IVF_PQ", "IVF_PQ",

View File

@@ -0,0 +1,87 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::BTreeSet;
use arrow::{
array::downcast_array,
compute::{concat_batches, filter_record_batch},
};
use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
use async_trait::async_trait;
use lance::dataset::ROW_ID;
use crate::error::{Error, Result};
pub mod rrf;
/// column name for reranker relevance score
const RELEVANCE_SCORE: &str = "_relevance_score";
#[derive(Debug, Clone, PartialEq)]
pub enum NormalizeMethod {
Score,
Rank,
}
/// Interface for a reranker. A reranker is used to rerank the results from a
/// vector and FTS search. This is useful for combining the results from both
/// search methods.
#[async_trait]
pub trait Reranker: std::fmt::Debug + Sync + Send {
// TODO support vector reranking and FTS reranking. Currently only hybrid reranking is supported.
/// Rerank function receives the individual results from the vector and FTS search
/// results. You can choose to use any of the results to generate the final results,
/// allowing maximum flexibility.
async fn rerank_hybrid(
&self,
query: &str,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> Result<RecordBatch>;
fn merge_results(
&self,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> Result<RecordBatch> {
let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;
let mut mask = BooleanArray::builder(combined.num_rows());
let mut unique_ids = BTreeSet::new();
let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
message: format!(
"could not find expected column {} while merging results. found columns {:?}",
ROW_ID,
combined
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>()
),
})?;
let row_ids: UInt64Array = downcast_array(row_ids);
row_ids.values().iter().for_each(|id| {
mask.append_value(unique_ids.insert(id));
});
let combined = filter_record_batch(&combined, &mask.finish())?;
Ok(combined)
}
}
pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
return Err(Error::Schema {
message: format!(
"rerank_hybrid must return a RecordBatch with a column named {}",
RELEVANCE_SCORE
),
});
}
Ok(())
}
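Custom rerankers only need to implement `rerank_hybrid` (`merge_results` has a default body). A hedged sketch of a delegating wrapper, assuming the module paths added in this diff and a public `lancedb::error::Result`; the returned batch must keep the `_relevance_score` column that `check_reranker_result` verifies:

```rust
use arrow_array::RecordBatch;
use async_trait::async_trait;
use lancedb::error::Result;
use lancedb::rerankers::{rrf::RRFReranker, Reranker};

// Hypothetical example: wraps RRF; a real implementation could rescore or
// filter the batch, as long as "_relevance_score" is present on return.
#[derive(Debug)]
struct LoggingReranker {
    inner: RRFReranker,
}

#[async_trait]
impl Reranker for LoggingReranker {
    async fn rerank_hybrid(
        &self,
        query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        println!(
            "reranking {} vector rows and {} fts rows for {:?}",
            vector_results.num_rows(),
            fts_results.num_rows(),
            query
        );
        self.inner.rerank_hybrid(query, vector_results, fts_results).await
    }
}
```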

View File

@@ -0,0 +1,223 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::BTreeMap;
use std::sync::Arc;
use arrow::{
array::downcast_array,
compute::{sort_to_indices, take},
};
use arrow_array::{Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use async_trait::async_trait;
use lance::dataset::ROW_ID;
use crate::error::{Error, Result};
use crate::rerankers::{Reranker, RELEVANCE_SCORE};
/// Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm,
/// based on the scores of the vector and FTS searches.
///
#[derive(Debug)]
pub struct RRFReranker {
k: f32,
}
impl RRFReranker {
/// Create a new RRFReranker
///
/// The parameter k is a constant used in the RRF formula (default is 60).
/// Experiments indicate that k = 60 was near-optimal, but that the choice
/// is not critical. See paper:
/// https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
pub fn new(k: f32) -> Self {
Self { k }
}
}
impl Default for RRFReranker {
fn default() -> Self {
Self { k: 60.0 }
}
}
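In symbols, the relevance computed in `rerank_hybrid` below for a row $d$ is the standard RRF sum, except that the position $p_L(d)$ of $d$ within each result list $L$ is zero-based rather than the textbook one-based rank:

\[
\operatorname{score}(d) \;=\; \sum_{L \,\in\, \{\text{vector},\ \text{fts}\}} \frac{1}{k + p_L(d)}
\]

with $k = 60$ by default; a row missing from one list simply contributes no term for it. The test at the bottom of this file works these sums through by hand with $k = 1$.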
#[async_trait]
impl Reranker for RRFReranker {
async fn rerank_hybrid(
&self,
_query: &str,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> Result<RecordBatch> {
let vector_ids = vector_results
.column_by_name(ROW_ID)
.ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in vector_results. found columns {:?}",
ROW_ID,
vector_results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>()
),
})?;
let fts_ids = fts_results
.column_by_name(ROW_ID)
.ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in fts_results. found columns {:?}",
ROW_ID,
fts_results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>()
),
})?;
let vector_ids: UInt64Array = downcast_array(&vector_ids);
let fts_ids: UInt64Array = downcast_array(&fts_ids);
let mut rrf_score_map = BTreeMap::new();
let mut update_score_map = |(i, result_id)| {
let score = 1.0 / (i as f32 + self.k);
rrf_score_map
.entry(result_id)
.and_modify(|e| *e += score)
.or_insert(score);
};
vector_ids
.values()
.iter()
.enumerate()
.for_each(&mut update_score_map);
fts_ids
.values()
.iter()
.enumerate()
.for_each(&mut update_score_map);
let combined_results = self.merge_results(vector_results, fts_results)?;
let combined_row_ids: UInt64Array =
downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
let relevance_scores = Float32Array::from_iter_values(
combined_row_ids
.values()
.iter()
.map(|row_id| rrf_score_map.get(row_id).unwrap())
.copied(),
);
// keep track of indices sorted by the relevance column
let sort_indices = sort_to_indices(
&relevance_scores,
Some(SortOptions {
descending: true,
..Default::default()
}),
None,
)
.unwrap();
// add relevance scores to columns
let mut columns = combined_results.columns().to_vec();
columns.push(Arc::new(relevance_scores));
// sort by the relevance scores
let columns = columns
.iter()
.map(|c| take(c, &sort_indices, None).unwrap())
.collect();
// add relevance score to schema
let mut fields = combined_results.schema().fields().to_vec();
fields.push(Arc::new(Field::new(
RELEVANCE_SCORE,
DataType::Float32,
false,
)));
let schema = Schema::new(fields);
let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;
Ok(combined_results)
}
}
#[cfg(test)]
pub mod test {
use super::*;
use arrow_array::StringArray;
#[tokio::test]
async fn test_rrf_reranker() {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new(ROW_ID, DataType::UInt64, false),
]));
let vec_results = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"])),
Arc::new(UInt64Array::from(vec![1, 4, 2, 5, 3])),
],
)
.unwrap();
let fts_results = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["bar", "bean", "dog"])),
Arc::new(UInt64Array::from(vec![4, 5, 3])),
],
)
.unwrap();
// scores should be calculated as:
// - foo = 1/1 = 1.0
// - bar = 1/2 + 1/1 = 1.5
// - baz = 1/3 = 0.333
// - bean = 1/4 + 1/2 = 0.75
// - dog = 1/5 + 1/3 = 0.533
// then we should get the result ranked in descending order
let reranker = RRFReranker::new(1.0);
let result = reranker
.rerank_hybrid("", vec_results, fts_results)
.await
.unwrap();
assert_eq!(3, result.schema().fields().len());
assert_eq!("name", result.schema().fields().first().unwrap().name());
assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name());
assert_eq!(
RELEVANCE_SCORE,
result.schema().fields().get(2).unwrap().name()
);
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["bar", "foo", "bean", "dog", "baz"]
);
let ids: UInt64Array = downcast_array(result.column(1));
assert_eq!(
ids.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![4, 1, 5, 3, 2]
);
let scores: Float32Array = downcast_array(result.column(2));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0]
);
}
}

View File

@@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable {
if let Some(ef) = query.ef {
scanner.ef(ef);
}
scanner.distance_range(query.lower_bound, query.upper_bound);
scanner.use_index(query.use_index);
scanner.prefilter(query.base.prefilter);
match query.base.select {