Compare commits

..

31 Commits

Author SHA1 Message Date
Lance Release
92d845fa72 Bump version: 0.17.2-beta.0 → 0.17.2-beta.1 2024-12-31 23:36:18 +00:00
Lei Xu
397813f6a4 chore: bump pylance to 0.21.1b1 (#1989) 2024-12-31 15:34:27 -08:00
Lei Xu
50c30c5d34 chore(python): fix typo of the synchronized checkout API (#1988) 2024-12-30 18:54:31 -08:00
Bert
c9f248b058 feat: add hybrid search to node and rust SDKs (#1940)
Support hybrid search in both the Rust and Node SDKs.

- Adds a new rerankers package to the Rust LanceDB crate, with the implementation of the default RRF reranker
- Adds a new hybrid package to lancedb, with helper methods related to hybrid search, such as normalizing scores and converting the score column to rank columns
- Adds the capability for the LanceDB VectorQuery to perform hybrid search when it has both nearest-vector and full-text-search parameters
- Adds wrappers for reranker implementations to the Node.js SDK.

Additional rerankers will be added in followup PRs

https://github.com/lancedb/lancedb/issues/1921

---
Notes about how the rust rerankers are wrapped for calling from JS:

I wanted to keep the core reranker logic, and the invocation of the
reranker by the query code, in Rust. This aligns with the philosophy of
the new Node SDK, where it's just a thin wrapper around Rust. However, I
also wanted to support users who want to add custom rerankers
written in JavaScript.

When we add a reranker to the query from JavaScript, it adds a special
Rust reranker that has a callback into the JavaScript code (which could
then turn around and call an underlying Rust reranker implementation if
desired). This adds a bit of complexity, but overall I think it moves us
in the right direction of keeping the majority of the query logic in the
underlying Rust SDK while leaving the option open to support custom
JavaScript rerankers.
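
For context, the Python SDK already exposes this same hybrid-search-plus-reranker flow, which the new Rust/Node.js surface mirrors. A minimal sketch (the table name is a placeholder; it assumes a table with both an FTS index and an embedding function configured, so both search legs can run):

```python
import lancedb
from lancedb.rerankers import RRFReranker

db = lancedb.connect("/tmp/lancedb")
table = db.open_table("mytable")  # hypothetical table with vector + FTS index

# Hybrid search runs a vector query and a full-text query, then the
# reranker (RRF here) fuses the two result sets into one ranking.
results = (
    table.search("dog", query_type="hybrid")
    .rerank(RRFReranker())
    .limit(5)
    .to_pandas()
)
```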
2024-12-30 09:03:41 -05:00
Renato Marroquin
0cb6da6b7e docs: add new indexes to python docs (#1945)
Closes issue #1855.

Co-authored-by: Renato Marroquin <renato.marroquin@oracle.com>
2024-12-28 15:35:10 -08:00
BubbleCal
aec8332eb5 chore: add dynamic = ["version"] to pass build check (#1977)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-12-28 10:45:23 -08:00
Lance Release
46061070e6 Updating package-lock.json 2024-12-26 07:40:12 +00:00
Lance Release
dae8334d0b Bump version: 0.17.1 → 0.17.2-beta.0 2024-12-25 08:28:59 +00:00
BubbleCal
8c81968b59 feat: support IVF_FLAT on remote table in rust (#1979)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-12-25 15:54:17 +08:00
BubbleCal
16cf2990f3 feat: create IVF_FLAT on remote table (#1978)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
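
A hedged sketch of how this surfaces in the Python sync API (the `db://` URI, API key, region, and table name below are placeholders):

```python
import lancedb

# Connect to a LanceDB Cloud project (placeholder credentials)
db = lancedb.connect("db://my-project", api_key="sk-...", region="us-east-1")
table = db.open_table("mytable")

# IVF_FLAT can now be requested on remote tables as well
table.create_index(metric="l2", index_type="IVF_FLAT", vector_column_name="vector")
```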
2024-12-25 14:57:07 +08:00
Will Jones
0a0f667bbd chore: fix typos (#1976) 2024-12-24 12:50:54 -08:00
Will Jones
03753fd84b ci(node): remove hardcoded toolchain from typescript release build (#1974)
We upgraded the toolchain in #1960, but didn't realize we had hardcoded it
in `npm-publish.yml`. I found that if I just remove the hardcoded
toolchain, the correct one is selected.

This didn't fully fix Windows Arm, so I created a follow-up issue here:
https://github.com/lancedb/lancedb/issues/1975
2024-12-24 12:48:41 -08:00
Lance Release
55cceaa309 Updating package-lock.json 2024-12-24 18:39:00 +00:00
Lance Release
c3797eb834 Updating package-lock.json 2024-12-24 18:38:44 +00:00
Lance Release
c0d0f38494 Bump version: 0.14.1-beta.7 → 0.14.1 2024-12-24 18:38:11 +00:00
Lance Release
6a8ab78d0a Bump version: 0.14.1-beta.6 → 0.14.1-beta.7 2024-12-24 18:38:06 +00:00
Lance Release
27404c8623 Bump version: 0.17.1-beta.7 → 0.17.1 2024-12-24 18:37:28 +00:00
Lance Release
f181c7e77f Bump version: 0.17.1-beta.6 → 0.17.1-beta.7 2024-12-24 18:37:27 +00:00
BubbleCal
e70fd4fecc feat: support IVF_FLAT, binary vectors and hamming distance (#1955)
Binary vectors and Hamming distance only work with IVF_FLAT, so this PR
introduces all three together.
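
A rough sketch of how this surfaces in the Python API, based on the `IvfFlat` config and binary-vector docs added in this diff (the `np.packbits` packing and the small partition count are illustrative choices, not prescribed by the PR):

```python
import lancedb
import numpy as np
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb")

# Binary vectors are stored as uint8 arrays: a 128-bit vector packs into 16 bytes
schema = pa.schema([
    pa.field("id", pa.int64()),
    pa.field("vector", pa.list_(pa.uint8(), 16)),
])
data = [
    {"id": i, "vector": np.packbits(np.random.randint(0, 2, size=128)).tolist()}
    for i in range(256)
]
table = db.create_table("binary_vectors", data, schema=schema)

# IVF_FLAT is the only index type that supports hamming distance
table.create_index(metric="hamming", index_type="IVF_FLAT", num_partitions=2)

# Search with a packed 128-bit query vector
query = np.packbits(np.random.randint(0, 2, size=128)).tolist()
results = table.search(query).metric("hamming").limit(10).to_pandas()
```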

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-12-24 10:36:20 -08:00
verma nakul
ac0068b80e feat(python): add ignore_missing to the async drop_table() method (#1953)
- feat(db): add `ignore_missing` to async `drop_table` method

Fixes #1951
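
A minimal sketch of the new flag on the async API (matching the signature added in this diff; the paths and table name are placeholders):

```python
import asyncio
import lancedb

async def main():
    db = await lancedb.connect_async("/tmp/lancedb")
    # Previously this raised ValueError when the table was missing;
    # with ignore_missing=True the call is a no-op instead.
    await db.drop_table("nonexistent_table", ignore_missing=True)

asyncio.run(main())
```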

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2024-12-24 10:33:47 -08:00
Hezi Zisman
ebac960571 feat(python): add bypass_vector_index to sync api (#1947)
Hi LanceDB team,

This PR adds the `bypass_vector_index` logic to the sync API, as
described in [Issue
#535](https://github.com/lancedb/lancedb/issues/535). (Closes #535).

I've implemented it only for regular vector search. If you think it
should also be supported for FTS, hybrid, or empty queries, or for the
cloud offering, please let me know and I'll be happy to extend it.
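
A quick sketch of the resulting sync usage (matching the builder method added in this diff; the table and query vector are placeholders):

```python
import lancedb

db = lancedb.connect("/tmp/lancedb")
table = db.open_table("mytable")  # hypothetical table with a vector index

# bypass_vector_index() forces an exhaustive (flat) scan, which is useful
# for computing ground-truth results when tuning nprobes for recall.
ground_truth = (
    table.search([0.1, 0.2])
    .bypass_vector_index()
    .limit(10)
    .to_pandas()
)
```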

Since there’s no `CONTRIBUTING.md` or contribution guidelines, I opted
for the simplest implementation to get this started.

Looking forward to your feedback!

Thanks!

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2024-12-24 10:33:26 -08:00
Lance Release
59b57055e7 Updating package-lock.json 2024-12-19 19:40:28 +00:00
Lance Release
591c8de8fc Updating package-lock.json 2024-12-19 19:40:13 +00:00
Lance Release
f835ff310f Bump version: 0.14.1-beta.5 → 0.14.1-beta.6 2024-12-19 19:39:41 +00:00
Lance Release
cf8c2edaf4 Bump version: 0.17.1-beta.5 → 0.17.1-beta.6 2024-12-19 19:39:08 +00:00
Will Jones
61a714a459 docs: improve optimization docs (#1957)
* Adds a `See Also` section to `cleanup_old_files` and `compact_files` so
readers know they are linked to `optimize`.
* Fixes the link to the `compact_files` arguments.
* Improves the formatting of a note.
2024-12-19 10:55:11 -08:00
Will Jones
5ddd84cec0 feat: upgrade lance to 0.21.0-beta.5 (#1961) 2024-12-19 10:54:59 -08:00
Will Jones
27ef0bb0a2 ci(rust): check MSRV and upgrade toolchain (#1960)
* Upgrades our toolchain file to v1.83.0, since many dependencies now
have an MSRV of 1.81.0
* Reverts the Rust changes from #1946 that were working around this in a
dumb way
* Adds an MSRV check
* Reduces the MSRV back to 1.78.0
2024-12-19 08:43:25 -08:00
Will Jones
25402ba6ec chore: update lockfiles (#1946) 2024-12-18 08:43:33 -08:00
Lance Release
37c359ed40 Updating package-lock.json 2024-12-13 22:38:04 +00:00
Lance Release
06cdf00987 Bump version: 0.14.1-beta.4 → 0.14.1-beta.5 2024-12-13 22:37:41 +00:00
63 changed files with 2166 additions and 241 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.14.1-beta.4"
current_version = "0.14.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -97,3 +97,7 @@ jobs:
if: ${{ !inputs.dry_run && inputs.other }}
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
- uses: ./.github/workflows/update_package_lock_nodejs
if: ${{ !inputs.dry_run && inputs.other }}
with:
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -159,7 +159,7 @@ jobs:
- name: Install common dependencies
run: |
apk add protobuf-dev curl clang mold grep npm bash
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
echo "source $HOME/.cargo/env" >> saved_env
echo "export CC=clang" >> saved_env
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -167,7 +167,7 @@ jobs:
if: ${{ matrix.config.arch == 'aarch64' }}
run: |
source "$HOME/.cargo/env"
rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
rustup target add aarch64-unknown-linux-musl
crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -262,7 +262,7 @@ jobs:
- name: Install common dependencies
run: |
apk add protobuf-dev curl clang mold grep npm bash openssl-dev openssl-libs-static
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
echo "source $HOME/.cargo/env" >> saved_env
echo "export CC=clang" >> saved_env
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=-crt-static,+avx2,+fma,+f16c -Clinker=clang -Clink-arg=-fuse-ld=mold'" >> saved_env
@@ -272,7 +272,7 @@ jobs:
if: ${{ matrix.config.arch == 'aarch64' }}
run: |
source "$HOME/.cargo/env"
rustup target add aarch64-unknown-linux-musl --toolchain 1.80.0
rustup target add aarch64-unknown-linux-musl
crt=$(realpath $(dirname $(rustup which rustc))/../lib/rustlib/aarch64-unknown-linux-musl/lib/self-contained)
sysroot_lib=/usr/aarch64-unknown-linux-musl/usr/lib
apk_url=https://dl-cdn.alpinelinux.org/alpine/latest-stable/main/aarch64/
@@ -334,50 +334,51 @@ jobs:
path: |
node/dist/lancedb-vectordb-win32*.tgz
node-windows-arm64:
name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
if: startsWith(github.ref, 'refs/tags/v')
runs-on: ubuntu-latest
container: alpine:edge
strategy:
fail-fast: false
matrix:
config:
# - arch: x86_64
- arch: aarch64
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install dependencies
run: |
apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
echo "source $HOME/.cargo/env" >> saved_env
echo "export CC=clang" >> saved_env
echo "export AR=llvm-ar" >> saved_env
source "$HOME/.cargo/env"
rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
(mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
- name: Configure x86_64 build
if: ${{ matrix.config.arch == 'x86_64' }}
run: |
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
- name: Configure aarch64 build
if: ${{ matrix.config.arch == 'aarch64' }}
run: |
echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
- name: Build Windows Artifacts
run: |
source ./saved_env
bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
- name: Upload Windows Artifacts
uses: actions/upload-artifact@v4
with:
name: node-native-windows-${{ matrix.config.arch }}
path: |
node/dist/lancedb-vectordb-win32*.tgz
# TODO: https://github.com/lancedb/lancedb/issues/1975
# node-windows-arm64:
# name: vectordb ${{ matrix.config.arch }}-pc-windows-msvc
# # if: startsWith(github.ref, 'refs/tags/v')
# runs-on: ubuntu-latest
# container: alpine:edge
# strategy:
# fail-fast: false
# matrix:
# config:
# # - arch: x86_64
# - arch: aarch64
# steps:
# - name: Checkout
# uses: actions/checkout@v4
# - name: Install dependencies
# run: |
# apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
# curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
# echo "source $HOME/.cargo/env" >> saved_env
# echo "export CC=clang" >> saved_env
# echo "export AR=llvm-ar" >> saved_env
# source "$HOME/.cargo/env"
# rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
# (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
# echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
# echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
# - name: Configure x86_64 build
# if: ${{ matrix.config.arch == 'x86_64' }}
# run: |
# echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
# - name: Configure aarch64 build
# if: ${{ matrix.config.arch == 'aarch64' }}
# run: |
# echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
# - name: Build Windows Artifacts
# run: |
# source ./saved_env
# bash ci/manylinux_node/build_vectordb.sh ${{ matrix.config.arch }} ${{ matrix.config.arch }}-pc-windows-msvc
# - name: Upload Windows Artifacts
# uses: actions/upload-artifact@v4
# with:
# name: node-native-windows-${{ matrix.config.arch }}
# path: |
# node/dist/lancedb-vectordb-win32*.tgz
nodejs-windows:
name: lancedb ${{ matrix.target }}
@@ -413,57 +414,58 @@ jobs:
path: |
nodejs/dist/*.node
nodejs-windows-arm64:
name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
runs-on: ubuntu-latest
container: alpine:edge
strategy:
fail-fast: false
matrix:
config:
# - arch: x86_64
- arch: aarch64
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install dependencies
run: |
apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y --default-toolchain 1.80.0
echo "source $HOME/.cargo/env" >> saved_env
echo "export CC=clang" >> saved_env
echo "export AR=llvm-ar" >> saved_env
source "$HOME/.cargo/env"
rustup target add ${{ matrix.config.arch }}-pc-windows-msvc --toolchain 1.80.0
(mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
chmod u+x $HOME/.cargo/bin/cargo-xwin
- name: Configure x86_64 build
if: ${{ matrix.config.arch == 'x86_64' }}
run: |
echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
- name: Configure aarch64 build
if: ${{ matrix.config.arch == 'aarch64' }}
run: |
echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
- name: Build Windows Artifacts
run: |
source ./saved_env
bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
- name: Upload Windows Artifacts
uses: actions/upload-artifact@v4
with:
name: nodejs-native-windows-${{ matrix.config.arch }}
path: |
nodejs/dist/*.node
# TODO: https://github.com/lancedb/lancedb/issues/1975
# nodejs-windows-arm64:
# name: lancedb ${{ matrix.config.arch }}-pc-windows-msvc
# # Only runs on tags that matches the make-release action
# # if: startsWith(github.ref, 'refs/tags/v')
# runs-on: ubuntu-latest
# container: alpine:edge
# strategy:
# fail-fast: false
# matrix:
# config:
# # - arch: x86_64
# - arch: aarch64
# steps:
# - name: Checkout
# uses: actions/checkout@v4
# - name: Install dependencies
# run: |
# apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
# curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
# echo "source $HOME/.cargo/env" >> saved_env
# echo "export CC=clang" >> saved_env
# echo "export AR=llvm-ar" >> saved_env
# source "$HOME/.cargo/env"
# rustup target add ${{ matrix.config.arch }}-pc-windows-msvc
# (mkdir -p sysroot && cd sysroot && sh ../ci/sysroot-${{ matrix.config.arch }}-pc-windows-msvc.sh)
# echo "export C_INCLUDE_PATH=/usr/${{ matrix.config.arch }}-pc-windows-msvc/usr/include" >> saved_env
# echo "export CARGO_BUILD_TARGET=${{ matrix.config.arch }}-pc-windows-msvc" >> saved_env
# printf '#!/bin/sh\ncargo "$@"' > $HOME/.cargo/bin/cargo-xwin
# chmod u+x $HOME/.cargo/bin/cargo-xwin
# - name: Configure x86_64 build
# if: ${{ matrix.config.arch == 'x86_64' }}
# run: |
# echo "export RUSTFLAGS='-Ctarget-cpu=haswell -Ctarget-feature=+crt-static,+avx2,+fma,+f16c -Clinker=lld -Clink-arg=/LIBPATH:/usr/x86_64-pc-windows-msvc/usr/lib'" >> saved_env
# - name: Configure aarch64 build
# if: ${{ matrix.config.arch == 'aarch64' }}
# run: |
# echo "export RUSTFLAGS='-Ctarget-feature=+crt-static,+neon,+fp16,+fhm,+dotprod -Clinker=lld -Clink-arg=/LIBPATH:/usr/aarch64-pc-windows-msvc/usr/lib -Clink-arg=arm64rt.lib'" >> saved_env
# - name: Build Windows Artifacts
# run: |
# source ./saved_env
# bash ci/manylinux_node/build_lancedb.sh ${{ matrix.config.arch }}
# - name: Upload Windows Artifacts
# uses: actions/upload-artifact@v4
# with:
# name: nodejs-native-windows-${{ matrix.config.arch }}
# path: |
# nodejs/dist/*.node
release:
name: vectordb NPM Publish
needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows, node-windows-arm64]
needs: [node, node-macos, node-linux-gnu, node-linux-musl, node-windows]
runs-on: ubuntu-latest
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
@@ -503,7 +505,7 @@ jobs:
release-nodejs:
name: lancedb NPM Publish
needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows, nodejs-windows-arm64]
needs: [nodejs-macos, nodejs-linux-gnu, nodejs-linux-musl, nodejs-windows]
runs-on: ubuntu-latest
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
@@ -571,7 +573,7 @@ jobs:
uses: actions/checkout@v4
with:
ref: main
persist-credentials: false
token: ${{ secrets.LANCEDB_RELEASE_TOKEN }}
fetch-depth: 0
lfs: true
- uses: ./.github/workflows/update_package_lock
@@ -589,7 +591,7 @@ jobs:
uses: actions/checkout@v4
with:
ref: main
persist-credentials: false
token: ${{ secrets.LANCEDB_RELEASE_TOKEN }}
fetch-depth: 0
lfs: true
- uses: ./.github/workflows/update_package_lock_nodejs

View File

@@ -185,7 +185,7 @@ jobs:
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
# Add MSVC runtime libraries to LIB
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
@@ -238,3 +238,41 @@ jobs:
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo build --target aarch64-pc-windows-msvc
cargo test --target aarch64-pc-windows-msvc
msrv:
# Check the minimum supported Rust version
name: MSRV Check - Rust v${{ matrix.msrv }}
runs-on: ubuntu-24.04
strategy:
matrix:
msrv: ["1.78.0"] # This should match up with rust-version in Cargo.toml
env:
# Need up-to-date compilers for kernels
CC: clang-18
CXX: clang++-18
steps:
- uses: actions/checkout@v4
with:
submodules: true
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Install ${{ matrix.msrv }}
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.msrv }}
- name: Downgrade dependencies
# These packages have newer requirements for MSRV
run: |
cargo update -p aws-sdk-bedrockruntime --precise 1.64.0
cargo update -p aws-sdk-dynamodb --precise 1.55.0
cargo update -p aws-config --precise 1.5.10
cargo update -p aws-sdk-kms --precise 1.51.0
cargo update -p aws-sdk-s3 --precise 1.65.0
cargo update -p aws-sdk-sso --precise 1.50.0
cargo update -p aws-sdk-ssooidc --precise 1.51.0
cargo update -p aws-sdk-sts --precise 1.51.0
cargo update -p home --precise 0.5.9
- name: cargo +${{ matrix.msrv }} check
run: cargo check --workspace --tests --benches --all-features

View File

@@ -18,19 +18,19 @@ repository = "https://github.com/lancedb/lancedb"
description = "Serverless, low-latency vector database for AI applications"
keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.21.0", "features" = [
lance = { "version" = "=0.21.1", "features" = [
"dynamodb",
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
lance-io = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
lance-index = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
lance-linalg = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
lance-table = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
lance-testing = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
lance-datafusion = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
lance-encoding = { version = "=0.21.0", git = "https://github.com/lancedb/lance.git", tag = "v0.21.0-beta.4" }
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
# Note that this one does not include pyarrow
arrow = { version = "53.2", optional = false }
arrow-array = "53.2"

View File

@@ -62,6 +62,7 @@ plugins:
# for cross references
- https://arrow.apache.org/docs/objects.inv
- https://pandas.pydata.org/docs/objects.inv
- https://lancedb.github.io/lance/objects.inv
- mkdocs-jupyter
- render_swagger:
allow_arbitrary_locations: true

View File

@@ -129,8 +129,16 @@ lists the indices that LanceDb supports.
::: lancedb.index.LabelList
::: lancedb.index.FTS
::: lancedb.index.IvfPq
::: lancedb.index.HnswPq
::: lancedb.index.HnswSq
::: lancedb.index.IvfFlat
## Querying (Asynchronous)
Queries allow you to return data from your database. Basic queries can be

View File

@@ -17,4 +17,8 @@ pip install lancedb
## Table
::: lancedb.remote.table.RemoteTable
options:
filters:
- "!cleanup_old_versions"
- "!compact_files"
- "!optimize"

View File

@@ -13,11 +13,15 @@ A vector search finds the approximate or exact nearest neighbors to a given quer
Distance metrics are a measure of the similarity between a pair of vectors.
Currently, LanceDB supports the following metrics:
| Metric | Description |
| -------- | --------------------------------------------------------------------------- |
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
| `dot` | [Dot Product](https://en.wikipedia.org/wiki/Dot_product) |
| Metric | Description |
| --------- | --------------------------------------------------------------------------- |
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
| `dot` | [Dot Product](https://en.wikipedia.org/wiki/Dot_product) |
| `hamming` | [Hamming Distance](https://en.wikipedia.org/wiki/Hamming_distance) |
!!! note
The `hamming` metric is only available for binary vectors.
## Exhaustive search (kNN)
@@ -107,6 +111,31 @@ an ANN search means that using an index often involves a trade-off between recal
See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ`
indexes work in LanceDB.
## Binary vector
LanceDB supports binary vectors as a data type, and has the ability to search binary vectors with hamming distance. The binary vectors are stored as uint8 arrays (every 8 bits are stored as a byte):
!!! note
The dim of the binary vector must be a multiple of 8. A vector of dim 128 will be stored as a uint8 array of size 16.
=== "Python"
=== "sync API"
```python
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
--8<-- "python/python/tests/docs/test_binary_vector.py:sync_binary_vector"
```
=== "async API"
```python
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
--8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector"
```
## Output search results
LanceDB returns vector search results via different formats commonly used in python.

View File

@@ -16,6 +16,7 @@ excluded_globs = [
"../src/concepts/*.md",
"../src/ann_indexes.md",
"../src/basic.md",
"../src/search.md",
"../src/hybrid_search/hybrid_search.md",
"../src/reranking/*.md",
"../src/guides/tuning_retrievers/*.md",

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.14.1-beta.4</version>
<version>0.14.1-final.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.14.1-beta.4</version>
<version>0.14.1-final.0</version>
<packaging>pom</packaging>
<name>LanceDB Parent</name>

node/package-lock.json (generated, 111 lines changed)
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"cpu": [
"x64",
"arm64"
@@ -52,14 +52,14 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.14.1-beta.4",
"@lancedb/vectordb-darwin-x64": "0.14.1-beta.4",
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.4",
"@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.4",
"@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.4",
"@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.4",
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.4",
"@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.4"
"@lancedb/vectordb-darwin-arm64": "0.14.1",
"@lancedb/vectordb-darwin-x64": "0.14.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
"@lancedb/vectordb-linux-arm64-musl": "0.14.1",
"@lancedb/vectordb-linux-x64-gnu": "0.14.1",
"@lancedb/vectordb-linux-x64-musl": "0.14.1",
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1",
"@lancedb/vectordb-win32-x64-msvc": "0.14.1"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
@@ -329,6 +329,97 @@
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.14.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.14.1.tgz",
"integrity": "sha512-6t7XHR7dBjDmAS/kz5wbe7LPhKW+WkFA16ZPyh0lmuxfnss4VvN3LE6qQBHjzYzB9U6Nu/4ktQ50xZGEPTnc5A==",
"cpu": [
"arm64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.14.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.14.1.tgz",
"integrity": "sha512-8q6Kd6XnNPKN8wqj75pHVQ4KFl6z9BaI6lWDiEaCNcO3bjPZkcLFNosJq4raxZ9iUi50Yl0qFJ6qR0XFVTwnnw==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.14.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.14.1.tgz",
"integrity": "sha512-4djEMmeNb+p6nW/C4xb8wdMwnIbWfO8fYAwiplOxzxeOpPaUC9rhwUUDCbrJDCpMa8RP5ED4/jC6yT8epaDMDw==",
"cpu": [
"arm64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-musl": {
"version": "0.14.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.14.1.tgz",
"integrity": "sha512-c33hSsp16pnC58plzx1OXuifp9Rachx/MshE/L/OReoutt74fFdrRJwUjE4UCAysyY5QdvTrNm9OhDjopQK2Bw==",
"cpu": [
"arm64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.14.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.14.1.tgz",
"integrity": "sha512-psu6cH9iLiSbUEZD1EWbOA4THGYSwJvS2XICO9yN7A6D41AP/ynYMRZNKWo1fpdi2Fjb0xNQwiNhQyqwbi5gzA==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-musl": {
"version": "0.14.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.14.1.tgz",
"integrity": "sha512-Rg4VWW80HaTFmR7EvNSu+nfRQQM8beO/otBn/Nus5mj5zFw/7cacGRmiEYhDnk5iAn8nauV+Jsi9j2U+C2hp5w==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.14.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.14.1.tgz",
"integrity": "sha512-XbifasmMbQIt3V9P0AtQND6M3XFiIAc1ZIgmjzBjOmxwqw4sQUwHMyJGIGOzKFZTK3fPJIGRHId7jAzXuBgfQg==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"win32"
]
},
"node_modules/@neon-rs/cli": {
"version": "0.0.160",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"description": " Serverless, low-latency vector database for AI applications",
"private": false,
"main": "dist/index.js",
@@ -92,13 +92,13 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-x64": "0.14.1-beta.4",
"@lancedb/vectordb-darwin-arm64": "0.14.1-beta.4",
"@lancedb/vectordb-linux-x64-gnu": "0.14.1-beta.4",
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1-beta.4",
"@lancedb/vectordb-linux-x64-musl": "0.14.1-beta.4",
"@lancedb/vectordb-linux-arm64-musl": "0.14.1-beta.4",
"@lancedb/vectordb-win32-x64-msvc": "0.14.1-beta.4",
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1-beta.4"
"@lancedb/vectordb-darwin-x64": "0.14.1",
"@lancedb/vectordb-darwin-arm64": "0.14.1",
"@lancedb/vectordb-linux-x64-gnu": "0.14.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.14.1",
"@lancedb/vectordb-linux-x64-musl": "0.14.1",
"@lancedb/vectordb-linux-arm64-musl": "0.14.1",
"@lancedb/vectordb-win32-x64-msvc": "0.14.1",
"@lancedb/vectordb-win32-arm64-msvc": "0.14.1"
}
}

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.14.1-beta.4"
version = "0.14.1"
license.workspace = true
description.workspace = true
repository.workspace = true
@@ -12,7 +12,10 @@ categories.workspace = true
crate-type = ["cdylib"]
[dependencies]
async-trait.workspace = true
arrow-ipc.workspace = true
arrow-array.workspace = true
arrow-schema.workspace = true
env_logger.workspace = true
futures.workspace = true
lancedb = { path = "../rust/lancedb", features = ["remote"] }

View File

@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";
import {
convertToTable,
fromBufferToRecordBatch,
fromRecordBatchToBuffer,
fromTableToBuffer,
makeArrowTable,
makeEmptyTable,
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
});
});
});
describe("converting record batches to buffers", function () {
it("can convert to buffered record batch and back again", async function () {
const records = [
{ text: "dog", vector: [0.1, 0.2] },
{ text: "cat", vector: [0.3, 0.4] },
];
const table = await convertToTable(records);
const batch = table.batches[0];
const buffer = await fromRecordBatchToBuffer(batch);
const result = await fromBufferToRecordBatch(buffer);
expect(JSON.stringify(batch.toArray())).toEqual(
JSON.stringify(result?.toArray()),
);
});
it("converting from buffer returns null if buffer has no record batches", async function () {
const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
expect(result).toEqual(null);
});
});
},
);

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { RecordBatch } from "apache-arrow";
import * as tmp from "tmp";
import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
import { RRFReranker } from "../lancedb/rerankers";
describe("rerankers", function () {
let tmpDir: tmp.DirResult;
let conn: Connection;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
conn = await connect(tmpDir.name);
table = await conn.createTable("mytable", [
{ vector: [0.1, 0.1], text: "dog" },
{ vector: [0.2, 0.2], text: "cat" },
]);
await table.createIndex("text", {
config: Index.fts(),
replace: true,
});
});
it("will query with the custom reranker", async function () {
const expectedResult = [
{
text: "albert",
// biome-ignore lint/style/useNamingConvention: this is the lance field name
_relevance_score: 0.99,
},
];
class MyCustomReranker {
async rerankHybrid(
_query: string,
_vecResults: RecordBatch,
_ftsResults: RecordBatch,
): Promise<RecordBatch> {
// no reranker logic, just return some static data
const table = makeArrowTable(expectedResult);
return table.batches[0];
}
}
let result = await table
.query()
.nearestTo([0.1, 0.1])
.fullTextSearch("dog")
.rerank(new MyCustomReranker())
.select(["text"])
.limit(5)
.toArray();
result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
expect(result).toEqual([
{
text: "albert",
// biome-ignore lint/style/useNamingConvention: this is the lance field name
_relevance_score: 0.99,
},
]);
});
it("will query with RRFReranker", async function () {
// smoke test to see if the Rust wrapping Typescript is wired up correctly
const result = await table
.query()
.nearestTo([0.1, 0.1])
.fullTextSearch("dog")
.rerank(await RRFReranker.create())
.select(["text"])
.limit(5)
.toArray();
expect(result).toHaveLength(2);
});
});

View File

@@ -27,7 +27,9 @@ import {
List,
Null,
RecordBatch,
RecordBatchFileReader,
RecordBatchFileWriter,
RecordBatchReader,
RecordBatchStreamWriter,
Schema,
Struct,
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
}
}
/**
* Read a single record batch from a buffer.
*
* Returns null if the buffer does not contain a record batch
*/
export async function fromBufferToRecordBatch(
data: Buffer,
): Promise<RecordBatch | null> {
const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
.value;
const recordBatch = iter?.next().value;
return recordBatch || null;
}
/**
* Create a buffer containing a single record batch
*/
export async function fromRecordBatchToBuffer(
batch: RecordBatch,
): Promise<Buffer> {
const writer = new RecordBatchFileWriter().writeAll([batch]);
return Buffer.from(await writer.toUint8Array());
}
/**
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
*

View File

@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";
export * as embedding from "./embedding";
export * as rerankers from "./rerankers";
/**
* Connect to a LanceDB instance at the given URI.

View File

@@ -16,6 +16,8 @@ import {
Table as ArrowTable,
type IntoVector,
RecordBatch,
fromBufferToRecordBatch,
fromRecordBatchToBuffer,
tableFromIPC,
} from "./arrow";
import { type IvfPqOptions } from "./indices";
@@ -25,6 +27,7 @@ import {
Table as NativeTable,
VectorQuery as NativeVectorQuery,
} from "./native";
import { Reranker } from "./rerankers";
export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
private promisedInner?: Promise<NativeBatchIterator>;
private inner?: NativeBatchIterator;
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
return this;
}
}
rerank(reranker: Reranker): VectorQuery {
super.doCall((inner) =>
inner.rerank({
rerankHybrid: async (_, args) => {
const vecResults = await fromBufferToRecordBatch(args.vecResults);
const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
const result = await reranker.rerankHybrid(
args.query,
vecResults as RecordBatch,
ftsResults as RecordBatch,
);
const buffer = fromRecordBatchToBuffer(result);
return buffer;
},
}),
);
return this;
}
}
/** A builder for LanceDB queries. */

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { RecordBatch } from "apache-arrow";
export * from "./rrf";
// Interface for a reranker. A reranker is used to rerank the results from a
// vector and FTS search. This is useful for combining the results from both
// search methods.
export interface Reranker {
rerankHybrid(
query: string,
vecResults: RecordBatch,
ftsResults: RecordBatch,
): Promise<RecordBatch>;
}

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { RecordBatch } from "apache-arrow";
import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
import { RrfReranker as NativeRRFReranker } from "../native";
/**
* Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
*
* Internally this uses the Rust implementation
*/
export class RRFReranker {
private inner: NativeRRFReranker;
constructor(inner: NativeRRFReranker) {
this.inner = inner;
}
public static async create(k: number = 60) {
return new RRFReranker(
await NativeRRFReranker.tryNew(new Float32Array([k])),
);
}
async rerankHybrid(
query: string,
vecResults: RecordBatch,
ftsResults: RecordBatch,
): Promise<RecordBatch> {
const buffer = await this.inner.rerankHybrid(
query,
await fromRecordBatchToBuffer(vecResults),
await fromRecordBatchToBuffer(ftsResults),
);
const recordBatch = await fromBufferToRecordBatch(buffer);
return recordBatch as RecordBatch;
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.14.1-beta.4",
"version": "0.14.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.14.0",
"version": "0.14.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.14.0",
"version": "0.14.1",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.14.1-beta.4",
"version": "0.14.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -24,6 +24,7 @@ mod iterator;
pub mod merge;
mod query;
pub mod remote;
mod rerankers;
mod table;
mod util;

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::ExecutableQuery;
use lancedb::query::Query as LanceDbQuery;
@@ -25,6 +27,8 @@ use napi_derive::napi;
use crate::error::convert_error;
use crate::error::NapiErrorExt;
use crate::iterator::RecordBatchIterator;
use crate::rerankers::Reranker;
use crate::rerankers::RerankerCallbacks;
use crate::util::parse_distance_type;
#[napi]
@@ -218,6 +222,14 @@ impl VectorQuery {
self.inner = self.inner.clone().with_row_id();
}
#[napi]
pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
self.inner = self
.inner
.clone()
.rerank(Arc::new(Reranker::new(callbacks)));
}
#[napi(catch_unwind)]
pub async fn execute(
&self,

nodejs/src/rerankers.rs (new file, 147 lines)
View File

@@ -0,0 +1,147 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow_array::RecordBatch;
use async_trait::async_trait;
use napi::{
bindgen_prelude::*,
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
};
use napi_derive::napi;
use lancedb::ipc::batches_to_ipc_file;
use lancedb::rerankers::Reranker as LanceDBReranker;
use lancedb::{error::Error, ipc::ipc_file_to_batches};
use crate::error::NapiErrorExt;
/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
/// This contains references to the callbacks that can be used to invoke the
/// reranking methods on the NodeJS implementation and handles serializing the
/// record batches to Arrow IPC buffers.
#[napi]
pub struct Reranker {
/// callback to the Javascript which will call the rerankHybrid method of
/// some Reranker implementation
rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
}
#[napi]
impl Reranker {
#[napi]
pub fn new(callbacks: RerankerCallbacks) -> Self {
let rerank_hybrid = callbacks
.rerank_hybrid
.create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
.unwrap();
Self { rerank_hybrid }
}
}
#[async_trait]
impl lancedb::rerankers::Reranker for Reranker {
async fn rerank_hybrid(
&self,
query: &str,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> lancedb::error::Result<RecordBatch> {
let callback_args = RerankHybridCallbackArgs {
query: query.to_string(),
vec_results: batches_to_ipc_file(&[vector_results])?,
fts_results: batches_to_ipc_file(&[fts_results])?,
};
let promised_buffer: Promise<Buffer> = self
.rerank_hybrid
.call_async(Ok(callback_args))
.await
.map_err(|e| Error::Runtime {
message: format!("napi error status={}, reason={}", e.status, e.reason),
})?;
let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
message: format!("napi error status={}, reason={}", e.status, e.reason),
})?;
let mut reader = ipc_file_to_batches(buffer.to_vec())?;
let result = reader.next().ok_or(Error::Runtime {
message: "reranker result deserialization failed".to_string(),
})??;
return Ok(result);
}
}
impl std::fmt::Debug for Reranker {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("NodeJSRerankerWrapper")
}
}
#[napi(object)]
pub struct RerankerCallbacks {
pub rerank_hybrid: JsFunction,
}
#[napi(object)]
pub struct RerankHybridCallbackArgs {
pub query: String,
pub vec_results: Vec<u8>,
pub fts_results: Vec<u8>,
}
fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
reader
.next()
.ok_or(Error::InvalidInput {
message: "expected buffer containing record batch".to_string(),
})
.default_error()?
.map_err(Error::from)
.default_error()
}
/// Wrapper around rust RRFReranker
#[napi]
pub struct RRFReranker {
inner: lancedb::rerankers::rrf::RRFReranker,
}
#[napi]
impl RRFReranker {
#[napi]
pub async fn try_new(k: &[f32]) -> Result<Self> {
let k = k
.first()
.copied()
.ok_or(Error::InvalidInput {
message: "must supply RRF Reranker constructor arg 'k'".to_string(),
})
.default_error()?;
Ok(Self {
inner: lancedb::rerankers::rrf::RRFReranker::new(k),
})
}
#[napi]
pub async fn rerank_hybrid(
&self,
query: String,
vec_results: Buffer,
fts_results: Buffer,
) -> Result<Buffer> {
let vec_results = buffer_to_record_batch(vec_results)?;
let fts_results = buffer_to_record_batch(fts_results)?;
let result = self
.inner
.rerank_hybrid(&query, vec_results, fts_results)
.await
.unwrap();
let result_buff = batches_to_ipc_file(&[result]).default_error()?;
Ok(Buffer::from(result_buff.as_ref()))
}
}

View File

@@ -5,8 +5,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
"hamming" => Ok(DistanceType::Hamming),
_ => Err(napi::Error::from_reason(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
distance_type.as_ref()
))),
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.17.1-beta.5"
current_version = "0.17.2-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.17.1-beta.5"
version = "0.17.2-beta.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -1,9 +1,10 @@
[project]
name = "lancedb"
# version in Cargo.toml
dynamic = ["version"]
dependencies = [
"deprecation",
"pylance==0.21.0b4",
"pylance==0.21.1b1",
"tqdm>=4.27.0",
"pydantic>=1.10",
"packaging",

View File

@@ -37,7 +37,7 @@ class Table:
async def count_rows(self, filter: Optional[str]) -> int: ...
async def create_index(self, column: str, config, replace: Optional[bool]): ...
async def version(self) -> int: ...
async def checkout(self, version): ...
async def checkout(self, version: int): ...
async def checkout_latest(self): ...
async def restore(self): ...
async def list_indices(self) -> List[IndexConfig]: ...

View File

@@ -18,12 +18,12 @@ from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from overrides import EnforceOverrides, override
from overrides import EnforceOverrides, override # type: ignore
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
from lancedb.background_loop import LOOP
from ._lancedb import connect as lancedb_connect
from ._lancedb import connect as lancedb_connect # type: ignore
from .table import (
AsyncTable,
LanceTable,
@@ -503,13 +503,7 @@ class LanceDBConnection(DBConnection):
ignore_missing: bool, default False
If True, ignore if the table does not exist.
"""
try:
LOOP.run(self._conn.drop_table(name))
except ValueError as e:
if not ignore_missing:
raise e
if f"Table '{name}' was not found" not in str(e):
raise e
LOOP.run(self._conn.drop_table(name, ignore_missing=ignore_missing))
@override
def drop_database(self):
@@ -886,15 +880,23 @@ class AsyncConnection(object):
"""
await self._inner.rename_table(old_name, new_name)
async def drop_table(self, name: str):
async def drop_table(self, name: str, *, ignore_missing: bool = False):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
ignore_missing: bool, default False
If True, ignore if the table does not exist.
"""
await self._inner.drop_table(name)
try:
await self._inner.drop_table(name)
except ValueError as e:
if not ignore_missing:
raise e
if f"Table '{name}' was not found" not in str(e):
raise e
async def drop_database(self):
"""

View File

@@ -355,6 +355,97 @@ class HnswSq:
ef_construction: int = 300
@dataclass
class IvfFlat:
"""Describes an IVF Flat Index
This index stores raw vectors.
These vectors are grouped into partitions of similar vectors.
Each partition keeps track of a centroid which is
the average value of all vectors in the group.
Attributes
----------
distance_type: str, default "L2"
The distance metric used to train the index
This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and to calculate a subvector's code during quantization.
The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
Note: the cosine distance is undefined when one (or both) of the vectors
are all zeros (there is no direction). These vectors are invalid and may
never be returned from a vector search.
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
"hamming" - Hamming distance. Hamming distance is a distance metric
calculated as the number of positions at which the corresponding bits are
different. Hamming distance has a range of [0, vector dimension].
num_partitions: int, default sqrt(num_rows)
The number of IVF partitions to create.
This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.
If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.
max_iterations: int, default 50
Max iteration to train kmeans.
When training an IVF PQ index we use kmeans to calculate the partitions.
This parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases
these extra iterations have diminishing returns.
The default value is 50.
sample_rate: int, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF PQ index is trained, we need to calculate partitions. These
are groups of vectors that are similar to each other. To do this we use an
algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we run
kmeans on a random sample of the data. This parameter controls the size of
the sample. The total number of vectors used to train the index is
`sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in most
cases the default should be sufficient.
The default value is 256.
"""
distance_type: Literal["l2", "cosine", "dot", "hamming"] = "l2"
num_partitions: Optional[int] = None
max_iterations: int = 50
sample_rate: int = 256
@dataclass
class IvfPq:
"""Describes an IVF PQ Index
@@ -477,4 +568,14 @@ class IvfPq:
sample_rate: int = 256
__all__ = ["BTree", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]
__all__ = [
"BTree",
"IvfPq",
"IvfFlat",
"HnswPq",
"HnswSq",
"IndexConfig",
"FTS",
"Bitmap",
"LabelList",
]

View File

@@ -126,6 +126,9 @@ class Query(pydantic.BaseModel):
ef: Optional[int] = None
# Default is true. Set to false to enforce a brute force search.
use_index: bool = True
class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search,
@@ -253,6 +256,7 @@ class LanceQueryBuilder(ABC):
self._vector = None
self._text = None
self._ef = None
self._use_index = True
@deprecation.deprecated(
deprecated_in="0.3.1",
@@ -511,6 +515,7 @@ class LanceQueryBuilder(ABC):
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
"use_index": self._use_index,
},
prefilter=self._prefilter,
filter=self._str_query,
@@ -729,6 +734,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
use_index=self._use_index,
)
result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
@@ -802,6 +808,24 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._str_query = query_string if query_string is not None else self._str_query
return self
def bypass_vector_index(self) -> LanceVectorQueryBuilder:
"""
If this is called then any vector index is skipped
An exhaustive (flat) search will be performed. The query vector will
be compared to every vector in the table. At high scales this can be
expensive. However, this is often still useful. For example, skipping
the vector index can give you ground truth results which you can use to
calculate your recall to select an appropriate value for nprobes.
Returns
-------
LanceVectorQueryBuilder
The LanceVectorQueryBuilder object.
"""
self._use_index = False
return self
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""
@@ -1108,6 +1132,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.refine_factor(self._refine_factor)
if self._ef:
self._vector_query.ef(self._ef)
if not self._use_index:
self._vector_query.bypass_vector_index()
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
@@ -1323,6 +1349,24 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._text = text
return self
def bypass_vector_index(self) -> LanceHybridQueryBuilder:
"""
If this is called then any vector index is skipped
An exhaustive (flat) search will be performed. The query vector will
be compared to every vector in the table. At high scales this can be
expensive. However, this is often still useful. For example, skipping
the vector index can give you ground truth results which you can use to
calculate your recall to select an appropriate value for nprobes.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._use_index = False
return self
class AsyncQueryBase(object):
def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):

View File

@@ -1,24 +1,15 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from datetime import timedelta
import logging
from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal
import warnings
from lancedb._lancedb import IndexConfig
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa
@@ -90,7 +81,7 @@ class RemoteTable(Table):
"""to_pandas() is not yet supported on LanceDB cloud."""
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
def checkout(self, version):
def checkout(self, version: int):
return LOOP.run(self._table.checkout(version))
def checkout_latest(self):
@@ -234,10 +225,12 @@ class RemoteTable(Table):
config = HnswPq(distance_type=metric)
elif index_type == "IVF_HNSW_SQ":
config = HnswSq(distance_type=metric)
elif index_type == "IVF_FLAT":
config = IvfFlat(distance_type=metric)
else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
LOOP.run(self._table.create_index(vector_column_name, config=config))
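With the branch above, a remote (LanceDB Cloud) table now accepts IVF_FLAT by name. A sketch using the parameter names visible in this diff; the table handle and column name are assumptions:

remote_tbl.create_index(
    metric="hamming",  # hamming assumes a binary (uint8) vector column
    vector_column_name="vector",
    index_type="IVF_FLAT",
)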
@@ -481,16 +474,28 @@ class RemoteTable(Table):
)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"cleanup_old_versions() is not supported on the LanceDB cloud"
"""
cleanup_old_versions() is a no-op on LanceDB Cloud.
Tables are automatically cleaned up and optimized.
"""
warnings.warn(
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
"Tables are automatically cleaned up and optimized."
)
pass
def compact_files(self, *_):
"""compact_files() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"compact_files() is not supported on the LanceDB cloud"
"""
compact_files() is a no-op on LanceDB Cloud.
Tables are automatically compacted and optimized.
"""
warnings.warn(
"compact_files() is a no-op on LanceDB Cloud. "
"Tables are automatically compacted and optimized."
)
pass
def optimize(
self,
@@ -498,12 +503,16 @@ class RemoteTable(Table):
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
"""optimize() is not supported on the LanceDB cloud.
Indices are optimized automatically."""
raise NotImplementedError(
"optimize() is not supported on the LanceDB cloud. "
"""
optimize() is a no-op on LanceDB Cloud.
Indices are optimized automatically.
"""
warnings.warn(
"optimize() is a no-op on LanceDB Cloud. "
"Indices are optimized automatically."
)
pass
def count_rows(self, filter: Optional[str] = None) -> int:
return LOOP.run(self._table.count_rows(filter))
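The three maintenance methods above now warn instead of raising, so existing cloud code keeps working. A test-style sketch of the new behavior; remote_tbl is an assumed RemoteTable handle:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    remote_tbl.cleanup_old_versions()
    remote_tbl.compact_files()
    remote_tbl.optimize()
# Each call emits a warning instead of raising NotImplementedError.
assert len(caught) == 3
assert all("no-op on LanceDB Cloud" in str(w.message) for w in caught)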

View File

@@ -34,7 +34,7 @@ from lance.dependencies import _check_for_hugging_face
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import (
@@ -433,7 +433,9 @@ class Table(ABC):
accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None,
*,
index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ",
index_type: Literal[
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ",
num_bits: int = 8,
max_iterations: int = 50,
sample_rate: int = 256,
@@ -446,8 +448,9 @@ class Table(ABC):
----------
metric: str, default "L2"
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
Valid values are "L2", "cosine", "dot", or "hamming".
L2 is Euclidean distance.
Hamming is available only for binary vectors.
num_partitions: int, default 256
The number of IVF partitions to use when creating the index.
Default is 256.
@@ -917,9 +920,6 @@ class Table(ABC):
"""
Clean up old versions of the table, freeing disk space.
Note: This function is not available in LanceDB Cloud (since LanceDB
Cloud manages cleanup for you automatically)
Parameters
----------
older_than: timedelta, default None
@@ -936,21 +936,38 @@ class Table(ABC):
CleanupStats
The stats of the cleanup operation, including how many bytes were
freed.
See Also
--------
[Table.optimize][lancedb.table.Table.optimize]: A more comprehensive
optimization operation that includes cleanup as well as other operations.
Notes
-----
This function is not available in LanceDB Cloud (since LanceDB
Cloud manages cleanup for you automatically).
"""
@abstractmethod
def compact_files(self, *args, **kwargs):
"""
Run the compaction process on the table.
Note: This function is not available in LanceDB Cloud (since LanceDB
Cloud manages compaction for you automatically)
This can be run after making several small appends to optimize the table
for faster reads.
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
Arguments are passed on to Lance's
[compact_files][lance.dataset.DatasetOptimizer.compact_files].
For most cases, the default should be fine.
See Also
--------
[Table.optimize][lancedb.table.Table.optimize]: A more comprehensive
optimization operation that includes cleanup as well as other operations.
Notes
-----
This function is not available in LanceDB Cloud (since LanceDB
Cloud manages compaction for you automatically)
"""
@abstractmethod
@@ -1075,7 +1092,7 @@ class Table(ABC):
"""
@abstractmethod
def checkout(self):
def checkout(self, version: int):
"""
Checks out a specific version of the Table
@@ -1394,7 +1411,9 @@ class LanceTable(Table):
accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None,
num_bits: int = 8,
index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ",
index_type: Literal[
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ",
max_iterations: int = 50,
sample_rate: int = 256,
m: int = 20,
@@ -1418,6 +1437,13 @@ class LanceTable(Table):
)
self.checkout_latest()
return
elif index_type == "IVF_FLAT":
config = IvfFlat(
distance_type=metric,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
)
elif index_type == "IVF_PQ":
config = IvfPq(
distance_type=metric,
@@ -2605,7 +2631,7 @@ class AsyncTable:
*,
replace: Optional[bool] = None,
config: Optional[
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
):
"""Create an index to speed up queries
@@ -2634,7 +2660,7 @@ class AsyncTable:
"""
if config is not None:
if not isinstance(
config, (IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
):
raise TypeError(
"config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"
@@ -2798,6 +2824,8 @@ class AsyncTable:
async_query = async_query.column(query.vector_column)
if query.ef:
async_query = async_query.ef(query.ef)
if not query.use_index:
async_query = async_query.bypass_vector_index()
if not query.prefilter:
async_query = async_query.postfilter()
@@ -3021,7 +3049,7 @@ class AsyncTable:
return versions
async def checkout(self, version):
async def checkout(self, version: int):
"""
Checks out a specific version of the Table

View File

@@ -0,0 +1,44 @@
import shutil
# --8<-- [start:imports]
import lancedb
import numpy as np
import pytest
# --8<-- [end:imports]
shutil.rmtree("data/binary_lancedb", ignore_errors=True)
def test_binary_vector():
# --8<-- [start:sync_binary_vector]
db = lancedb.connect("data/binary_lancedb")
data = [
{
"id": i,
"vector": np.random.randint(0, 256, size=16),
}
for i in range(1024)
]
tbl = db.create_table("my_binary_vectors", data=data)
query = np.random.randint(0, 256, size=16)
tbl.search(query).to_arrow()
# --8<-- [end:sync_binary_vector]
db.drop_table("my_binary_vectors")
@pytest.mark.asyncio
async def test_binary_vector_async():
# --8<-- [start:async_binary_vector]
db = await lancedb.connect_async("data/binary_lancedb")
data = [
{
"id": i,
"vector": np.random.randint(0, 256, size=16),
}
for i in range(1024)
]
tbl = await db.create_table("my_binary_vectors", data=data)
query = np.random.randint(0, 256, size=16)
await tbl.query().nearest_to(query).to_arrow()
# --8<-- [end:async_binary_vector]
await db.drop_table("my_binary_vectors")

View File

@@ -508,6 +508,32 @@ def test_delete_table(tmp_db: lancedb.DBConnection):
tmp_db.drop_table("does_not_exist", ignore_missing=True)
@pytest.mark.asyncio
async def test_delete_table_async(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
tmp_db.create_table("test", data=data)
assert tmp_db.table_names() == ["test"]
tmp_db.drop_table("test")
assert tmp_db.table_names() == []
tmp_db.create_table("test", data=data)
assert tmp_db.table_names() == ["test"]
tmp_db.drop_table("does_not_exist", ignore_missing=True)
def test_drop_database(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
@@ -681,3 +707,25 @@ def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection):
with pytest.raises(ValueError):
tmp_db.create_table("foo$$bar", data)
tmp_db.create_table("foo.bar", data)
def test_bypass_vector_index_sync(tmp_db: lancedb.DBConnection):
data = [{"vector": np.random.rand(32)} for _ in range(512)]
sample_key = data[100]["vector"]
table = tmp_db.create_table(
"test",
data,
)
table.create_index(
num_partitions=2,
num_sub_vectors=2,
)
plan_with_index = table.search(sample_key).explain_plan(verbose=True)
assert "ANN" in plan_with_index
plan_without_index = (
table.search(sample_key).bypass_vector_index().explain_plan(verbose=True)
)
assert "KNN" in plan_without_index

View File

@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
@pytest_asyncio.fixture
@@ -42,6 +42,27 @@ async def some_table(db_async):
)
@pytest_asyncio.fixture
async def binary_table(db_async):
data = [
{
"id": i,
"vector": [i] * 128,
}
for i in range(NROWS)
]
return await db_async.create_table(
"binary_table",
data,
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field("vector", pa.list_(pa.uint8(), 128)),
]
),
)
@pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable):
# Can create
@@ -143,3 +164,27 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
indices = await some_table.list_indices()
assert len(indices) == 1
@pytest.mark.asyncio
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
await binary_table.create_index(
"vector", config=IvfFlat(distance_type="hamming", num_partitions=10)
)
indices = await binary_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "IvfFlat"
assert indices[0].columns == ["vector"]
assert indices[0].name == "vector_idx"
stats = await binary_table.index_stats("vector_idx")
assert stats.index_type == "IVF_FLAT"
assert stats.distance_type == "hamming"
assert stats.num_indexed_rows == await binary_table.count_rows()
assert stats.num_unindexed_rows == 0
assert stats.num_indices == 1
# the dataset contains vectors with all values from 0 to 255
for v in range(256):
res = await binary_table.query().nearest_to([v] * 128).to_arrow()
assert res["id"][0].as_py() == v

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use lancedb::index::vector::IvfFlatIndexBuilder;
use lancedb::index::{
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
@@ -59,6 +60,18 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
opts.tokenizer_configs = inner_opts;
Ok(LanceDbIndex::FTS(opts))
},
"IvfFlat" => {
let params = source.extract::<IvfFlatParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut ivf_flat_builder = IvfFlatIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate);
if let Some(num_partitions) = params.num_partitions {
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
}
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
},
"IvfPq" => {
let params = source.extract::<IvfPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -129,6 +142,14 @@ struct FtsParams {
ascii_folding: bool,
}
#[derive(FromPyObject)]
struct IvfFlatParams {
distance_type: String,
num_partitions: Option<u32>,
max_iterations: u32,
sample_rate: u32,
}
#[derive(FromPyObject)]
struct IvfPqParams {
distance_type: String,

View File

@@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
"hamming" => Ok(DistanceType::Hamming),
_ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
distance_type.as_ref()
))),
}

View File

@@ -1,2 +1,2 @@
[toolchain]
channel = "1.80.0"
channel = "1.83.0"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.14.1-beta.4"
version = "0.14.1"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,13 +1,13 @@
[package]
name = "lancedb"
version = "0.14.1-beta.4"
version = "0.14.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
rust-version = "1.75"
rust-version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View File

@@ -17,6 +17,7 @@ use std::sync::Arc;
use scalar::FtsIndexBuilder;
use serde::Deserialize;
use serde_with::skip_serializing_none;
use vector::IvfFlatIndexBuilder;
use crate::{table::TableInternal, DistanceType, Error, Result};
@@ -56,6 +57,9 @@ pub enum Index {
/// Full text search index using bm25.
FTS(FtsIndexBuilder),
/// IVF index
IvfFlat(IvfFlatIndexBuilder),
/// IVF index with Product Quantization
IvfPq(IvfPqIndexBuilder),
@@ -106,6 +110,8 @@ impl IndexBuilder {
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub enum IndexType {
// Vector
#[serde(alias = "IVF_FLAT")]
IvfFlat,
#[serde(alias = "IVF_PQ")]
IvfPq,
#[serde(alias = "IVF_HNSW_PQ")]
@@ -127,6 +133,7 @@ pub enum IndexType {
impl std::fmt::Display for IndexType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::IvfFlat => write!(f, "IVF_FLAT"),
Self::IvfPq => write!(f, "IVF_PQ"),
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
@@ -147,6 +154,7 @@ impl std::str::FromStr for IndexType {
"BITMAP" => Ok(Self::Bitmap),
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
"FTS" | "INVERTED" => Ok(Self::FTS),
"IVF_FLAT" => Ok(Self::IvfFlat),
"IVF_PQ" => Ok(Self::IvfPq),
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),

View File

@@ -162,6 +162,43 @@ macro_rules! impl_hnsw_params_setter {
};
}
/// Builder for an IVF Flat index.
///
/// This index stores raw vectors. These vectors are grouped into partitions of similar vectors.
/// Each partition keeps track of a centroid which is the average value of all vectors in the group.
///
/// During a query the centroids are compared with the query vector to find the closest partitions.
/// The raw vectors in these partitions are then searched to find the closest vectors.
///
/// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create.
///
/// Note that training an IVF Flat index on a large dataset is currently a slow and memory-intensive operation.
#[derive(Debug, Clone)]
pub struct IvfFlatIndexBuilder {
pub(crate) distance_type: DistanceType,
// IVF
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
}
impl Default for IvfFlatIndexBuilder {
fn default() -> Self {
Self {
distance_type: DistanceType::L2,
num_partitions: None,
sample_rate: 256,
max_iterations: 50,
}
}
}
impl IvfFlatIndexBuilder {
impl_distance_type_setter!();
impl_ivf_params_setter!();
}
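The builder's doc comment describes the IVF mechanics. The following toy numpy sketch (not the Lance implementation) illustrates the partition-then-flat-scan idea, assuming centroids have already been trained, e.g. by k-means:

import numpy as np

def ivf_flat_search(vectors, centroids, query, nprobes=2, k=5):
    # IVF partitioning: assign every vector to its nearest centroid.
    assign = np.linalg.norm(vectors[:, None] - centroids[None], axis=2).argmin(axis=1)
    # Probe only the partitions whose centroids are closest to the query.
    probed = np.linalg.norm(centroids - query, axis=1).argsort()[:nprobes]
    candidates = np.flatnonzero(np.isin(assign, probed))
    # Flat search: exhaustively scan the raw vectors in the probed partitions.
    dists = np.linalg.norm(vectors[candidates] - query, axis=1)
    return candidates[dists.argsort()[:k]]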
/// Builder for an IVF PQ index.
///
/// This index stores a compressed (quantized) copy of every vector. These vectors

View File

@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
pub mod query;
#[cfg(feature = "remote")]
pub mod remote;
pub mod rerankers;
pub mod table;
pub mod utils;

View File

@@ -15,19 +15,31 @@
use std::future::Future;
use std::sync::Arc;
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use half::f16;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::{
arrow::RecordBatchExt,
dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
};
use lance_datafusion::exec::execute_plan;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::TableInternal;
use crate::DistanceType;
mod hybrid;
pub(crate) const DEFAULT_TOP_K: usize = 10;
/// Which columns should be retrieved from the database
@@ -339,7 +351,7 @@ pub trait QueryBase {
fn limit(self, limit: usize) -> Self;
/// Set the offset of the query.
///
/// By default, it fetches starting with the first row.
/// This method can be used to skip the first `offset` rows.
fn offset(self, offset: usize) -> Self;
@@ -435,6 +447,16 @@ pub trait QueryBase {
/// Return the `_rowid` meta column from the Table.
fn with_row_id(self) -> Self;
/// Rerank the results using the specified reranker.
///
/// This is currently only supported for Hybrid Search.
fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;
/// The method to normalize the scores. Can be "rank" or "score". If "rank",
/// the scores are converted to ranks and then normalized. If "score", the
/// scores are normalized directly.
fn norm(self, norm: NormalizeMethod) -> Self;
}
pub trait HasQuery {
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
self.mut_query().with_row_id = true;
self
}
fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
self.mut_query().reranker = Some(reranker);
self
}
fn norm(mut self, norm: NormalizeMethod) -> Self {
self.mut_query().norm = Some(norm);
self
}
}
/// Options for controlling the execution of a query
@@ -600,6 +632,13 @@ pub struct Query {
/// If set to false, the filter will be applied after the vector search.
pub(crate) prefilter: bool,
/// Implementation of reranker that can be used to reorder or combine query
/// results, especially if using hybrid search
pub(crate) reranker: Option<Arc<dyn Reranker>>,
/// Configure how query results are normalized when doing hybrid search
pub(crate) norm: Option<NormalizeMethod>,
}
impl Query {
@@ -614,6 +653,8 @@ impl Query {
fast_search: false,
with_row_id: false,
prefilter: true,
reranker: None,
norm: None,
}
}
@@ -862,6 +903,65 @@ impl VectorQuery {
self.use_index = false;
self
}
pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
// clone the query and specify we want to include row IDs, which may be needed for reranking
let fts_query = self.base.clone().with_row_id();
let mut vector_query = self.clone().with_row_id();
vector_query.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
let (fts_results, vec_results) = try_join!(
fts_results.try_collect::<Vec<_>>(),
vec_results.try_collect::<Vec<_>>()
)?;
// Try to get the schemas to use when combining batches. If either result
// set is empty, its schema is derived from the other (see hybrid::query_schemas).
let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);
// concatenate all the batches together
let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
}
vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
let reranker = self
.base
.reranker
.clone()
.unwrap_or(Arc::new(RRFReranker::default()));
let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
message: "there should be an FTS search".to_string(),
})?;
let mut results = reranker
.rerank_hybrid(&fts_query.query, vec_results, fts_results)
.await?;
check_reranker_result(&results)?;
let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
if results.num_rows() > limit {
results = results.slice(0, limit);
}
if !self.base.with_row_id {
results = results.drop_column(ROW_ID)?;
}
Ok(SendableRecordBatchStream::from(
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
}
}
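The control flow of execute_hybrid above, restated as schematic Python. Here rank, normalize_scores, and rrf_rerank stand in for the hybrid helpers and are not real public APIs; "_distance" and "_score" are the vector distance and FTS score columns:

def hybrid_search(table, text_query, vector_query, limit=10, norm="score"):
    # Run both searches, keeping row ids so results can be matched up.
    fts = table.search(text_query, query_type="fts").with_row_id(True).to_arrow()
    vec = table.search(vector_query).with_row_id(True).to_arrow()
    if norm == "rank":
        vec, fts = rank(vec, "_distance"), rank(fts, "_score")
    # Min-max normalize both score columns, then fuse (RRF by default).
    vec = normalize_scores(vec, "_distance")
    fts = normalize_scores(fts, "_score")
    return rrf_rerank(text_query, vec, fts).slice(0, limit)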
impl ExecutableQuery for VectorQuery {
@@ -873,6 +973,11 @@ impl ExecutableQuery for VectorQuery {
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
if self.base.full_text_search.is_some() {
let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
return Ok(hybrid_result);
}
Ok(SendableRecordBatchStream::from(
DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
@@ -894,20 +999,20 @@ impl HasQuery for VectorQuery {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};
use super::*;
use arrow::{compute::concat_batches, datatypes::Int32Type};
use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
use arrow_array::{
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
RecordBatchReader,
cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use futures::{StreamExt, TryStreamExt};
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
use tempfile::tempdir;
use crate::{connect, Table};
use crate::{connect, connection::CreateTableMode, Table};
#[tokio::test]
async fn test_setters_getters() {
@@ -1274,4 +1379,156 @@ mod tests {
assert!(query_index.values().contains(&0));
assert!(query_index.values().contains(&1));
}
#[tokio::test]
async fn test_hybrid_search() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
let vectors = vec![
Some(vec![Some(0.0), Some(0.0)]),
Some(vec![Some(-2.0), Some(-2.0)]),
Some(vec![Some(50.0), Some(50.0)]),
Some(vec![Some(-30.0), Some(-30.0)]),
];
let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);
let record_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
let record_batch_iter =
RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
let table = conn
.create_table("my_table", record_batch_iter)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let fts_query = FullTextSearchQuery::new("b".to_string());
let results = table
.query()
.full_text_search(fts_query)
.limit(2)
.nearest_to(&[-10.0, -10.0])
.unwrap()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &results[0];
let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
assert!(texts.contains("cat")); // should be close by vector search
assert!(texts.contains("b")); // should be close by fts search
// ensure that this works correctly if there are no matching FTS results
let fts_query = FullTextSearchQuery::new("z".to_string());
table
.query()
.full_text_search(fts_query)
.limit(2)
.nearest_to(&[-10.0, -10.0])
.unwrap()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
}
#[tokio::test]
async fn test_hybrid_search_empty_table() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
// ensure hybrid search is also supported on a fully empty table
let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
let record_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(Vec::<&str>::new())),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
),
],
)
.unwrap();
let record_batch_iter =
RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
let table = conn
.create_table("my_table", record_batch_iter)
.mode(CreateTableMode::Overwrite)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let fts_query = FullTextSearchQuery::new("b".to_string());
let results = table
.query()
.full_text_search(fts_query)
.limit(2)
.nearest_to(&[-10.0, -10.0])
.unwrap()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &results[0];
assert_eq!(0, batch.num_rows());
assert_eq!(2, batch.num_columns());
}
}

View File

@@ -0,0 +1,346 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow::compute::{
kernels::numeric::{div, sub},
max, min,
};
use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use lance::dataset::ROW_ID;
use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
use std::sync::Arc;
use crate::error::{Error, Result};
/// Converts the results' score column to ranks.
///
/// Expects the `column` argument to be of type Float32 and will panic if it's not.
pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in rank. found columns {:?}",
column,
results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>(),
),
})?;
if results.num_rows() == 0 {
return Ok(results);
}
let scores: Float32Array = downcast_array(scores);
let ranks = Float32Array::from_iter_values(
arrow::compute::kernels::rank::rank(
&scores,
Some(SortOptions {
descending: !ascending.unwrap_or(true),
..Default::default()
}),
)?
.iter()
.map(|i| *i as f32),
);
let schema = results.schema();
let (column_idx, _) = schema.column_with_name(column).unwrap();
let mut columns = results.columns().to_vec();
columns[column_idx] = Arc::new(ranks);
let results = RecordBatch::try_new(results.schema(), columns)?;
Ok(results)
}
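A numpy equivalent of rank above, matching the 1-based ranks asserted in the tests below (tie handling may differ from arrow's rank kernel):

import numpy as np

def rank(scores, ascending=True):
    s = np.asarray(scores, dtype=np.float32)
    if not ascending:
        s = -s
    # argsort of argsort yields 0-based positions; +1 gives 1-based ranks.
    return (s.argsort().argsort() + 1).astype(np.float32)

# rank([0.2, 0.4, 0.1, 0.6, 0.45]) -> [2, 3, 1, 5, 4]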
/// Get the query schemas needed when combining the search results.
///
/// If either of the record batches are empty, then we create a schema from the
/// other record batch, and replace the score/distance column. If both record
/// batches are empty, create empty schemas.
pub fn query_schemas(
fts_results: &[RecordBatch],
vec_results: &[RecordBatch],
) -> (Arc<Schema>, Arc<Schema>) {
let (fts_schema, vec_schema) = match (
fts_results.first().map(|r| r.schema()),
vec_results.first().map(|r| r.schema()),
) {
(Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
(None, Some(vec_schema)) => {
let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
(Arc::new(fts_schema), vec_schema)
}
(Some(fts_schema), None) => {
let vec_schema = with_field_name_replaced(&fts_schema, SCORE_COL, DIST_COL);
(fts_schema, Arc::new(vec_schema))
}
(None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
};
(fts_schema, vec_schema)
}
pub fn empty_fts_schema() -> Schema {
Schema::new(vec![
Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
])
}
pub fn empty_vec_schema() -> Schema {
Schema::new(vec![
Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
])
}
pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
if field.name() == target {
Some(i)
} else {
None
}
});
let mut fields = schema.fields().to_vec();
if let Some(idx) = field_idx {
let new_field = (*fields[idx]).clone().with_name(replacement);
fields[idx] = Arc::new(new_field);
}
Schema::new(fields)
}
/// Normalize the scores column to have values between 0 and 1.
///
/// Expects the `column` argument to be of type Float32 and will panic if it's not.
pub fn normalize_scores(
results: RecordBatch,
column: &str,
invert: Option<bool>,
) -> Result<RecordBatch> {
let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in rank. found columns {:?}",
column,
results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>(),
),
})?;
if results.num_rows() == 0 {
return Ok(results);
}
let mut scores: Float32Array = downcast_array(scores);
let max = max(&scores).unwrap_or(0.0);
let min = min(&scores).unwrap_or(0.0);
// this mirrors np.isclose, which the Python implementation uses
let rng = if max - min < 10e-5 { max } else { max - min };
// if rng is 0, then min and max are both 0 so we just leave the scores as is
if rng != 0.0 {
let tmp = div(
&sub(&scores, &Float32Array::new_scalar(min))?,
&Float32Array::new_scalar(rng),
)?;
scores = downcast_array(&tmp);
}
if invert.unwrap_or(false) {
let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
scores = downcast_array(&tmp);
}
let schema = results.schema();
let (column_idx, _) = schema.column_with_name(column).unwrap();
let mut columns = results.columns().to_vec();
columns[column_idx] = Arc::new(scores);
let results = RecordBatch::try_new(results.schema(), columns).unwrap();
Ok(results)
}
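The same min-max normalization in numpy, including the near-equal guard (10e-5 == 1e-4) that mirrors the np.isclose behavior:

import numpy as np

def normalize_scores(scores, invert=False):
    s = np.asarray(scores, dtype=np.float32)
    if s.size == 0:
        return s
    lo, hi = s.min(), s.max()
    # If all scores are (nearly) equal, divide by max rather than the span.
    rng = hi if hi - lo < 1e-4 else hi - lo
    if rng != 0.0:
        s = (s - lo) / rng
    return 1.0 - s if invert else s

# normalize_scores([-4.0, 2.0, 0.0, 3.0, 6.0]) -> [0.0, 0.6, 0.4, 0.7, 1.0]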
#[cfg(test)]
mod test {
use super::*;
use arrow_array::StringArray;
use arrow_schema::{DataType, Field, Schema};
#[test]
fn test_rank() {
let schema = Arc::new(Schema::new(vec![
Arc::new(Field::new("name", DataType::Utf8, false)),
Arc::new(Field::new("score", DataType::Float32, false)),
]));
let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]);
let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]);
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap();
let result = rank(batch.clone(), "score", Some(false)).unwrap();
assert_eq!(2, result.schema().fields().len());
assert_eq!("name", result.schema().field(0).name());
assert_eq!("score", result.schema().field(1).name());
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![4.0, 3.0, 5.0, 1.0, 2.0]
);
// check sort ascending
let result = rank(batch.clone(), "score", Some(true)).unwrap();
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![2.0, 3.0, 1.0, 5.0, 4.0]
);
// ensure default sort is ascending
let result = rank(batch.clone(), "score", None).unwrap();
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![2.0, 3.0, 1.0, 5.0, 4.0]
);
// check it can handle an empty batch
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(Vec::<&str>::new())),
Arc::new(Float32Array::from(Vec::<f32>::new())),
],
)
.unwrap();
let result = rank(batch.clone(), "score", None).unwrap();
assert_eq!(0, result.num_rows());
assert_eq!(2, result.schema().fields().len());
assert_eq!("name", result.schema().field(0).name());
assert_eq!("score", result.schema().field(1).name());
// check it returns the expected error when there's no column
let result = rank(batch.clone(), "bad_col", None);
match result {
Err(Error::InvalidInput { message }) => {
assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
}
_ => {
panic!("expected invalid input error, received {:?}", result)
}
}
}
#[test]
fn test_normalize_scores() {
let schema = Arc::new(Schema::new(vec![
Arc::new(Field::new("name", DataType::Utf8, false)),
Arc::new(Field::new("score", DataType::Float32, false)),
]));
let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "bean", "dog"]
);
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.6, 0.4, 0.7, 1.0]
);
// check it can invert the normalization
let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
);
// check that the default is not inverted
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.6, 0.4, 0.7, 1.0]
);
// check that it will function correctly if all the values are the same
let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.0, 0.0, 0.0, 0.0]
);
// check it keeps floating point rounding errors for same score normalized the same
// e.g., the behaviour is consistent with python
let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![
1.0 - 0.9999999,
1.0 - 0.9999999,
1.0 - 0.9999999,
1.0 - 0.9999999,
0.0
]
);
// check that it can handle if all the scores are 0
let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
let batch =
RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
let result = normalize_scores(batch.clone(), "score", None).unwrap();
let scores: Float32Array = downcast_array(result.column(1));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![0.0, 0.0, 0.0, 0.0, 0.0]
);
}
}

View File

@@ -563,6 +563,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let (index_type, distance_type) = match index.index {
// TODO: Should we pass the actual index parameters? SaaS does not
// yet support them.
Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)),
Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
Index::BTree(_) => ("BTREE", None),
@@ -873,6 +874,7 @@ mod tests {
use lance_index::scalar::FullTextSearchQuery;
use reqwest::Body;
use crate::index::vector::IvfFlatIndexBuilder;
use crate::{
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
query::{ExecutableQuery, QueryBase},
@@ -1489,6 +1491,11 @@ mod tests {
#[tokio::test]
async fn test_create_index() {
let cases = [
(
"IVF_FLAT",
Some("hamming"),
Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
),
("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
(
"IVF_PQ",

View File

@@ -0,0 +1,87 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::BTreeSet;
use arrow::{
array::downcast_array,
compute::{concat_batches, filter_record_batch},
};
use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
use async_trait::async_trait;
use lance::dataset::ROW_ID;
use crate::error::{Error, Result};
pub mod rrf;
/// column name for reranker relevance score
const RELEVANCE_SCORE: &str = "_relevance_score";
#[derive(Debug, Clone, PartialEq)]
pub enum NormalizeMethod {
Score,
Rank,
}
/// Interface for a reranker. A reranker is used to rerank the results from a
/// vector and FTS search. This is useful for combining the results from both
/// search methods.
#[async_trait]
pub trait Reranker: std::fmt::Debug + Sync + Send {
// TODO support vector reranking and FTS reranking. Currently only hybrid reranking is supported.
/// The rerank function receives the individual results from the vector and
/// FTS searches. You can choose to use any of the results to generate the
/// final results, allowing maximum flexibility.
async fn rerank_hybrid(
&self,
query: &str,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> Result<RecordBatch>;
fn merge_results(
&self,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> Result<RecordBatch> {
let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;
let mut mask = BooleanArray::builder(combined.num_rows());
let mut unique_ids = BTreeSet::new();
let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
message: format!(
"could not find expected column {} while merging results. found columns {:?}",
ROW_ID,
combined
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>()
),
})?;
let row_ids: UInt64Array = downcast_array(row_ids);
row_ids.values().iter().for_each(|id| {
mask.append_value(unique_ids.insert(id));
});
let combined = filter_record_batch(&combined, &mask.finish())?;
Ok(combined)
}
}
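merge_results deduplicates by row id, with vector results taking precedence because they come first in the concatenation. A plain-Python sketch over lists of dicts keyed by the "_rowid" column:

def merge_results(vec_rows, fts_rows):
    seen, merged = set(), []
    for row in vec_rows + fts_rows:  # vector rows first, so they win ties
        if row["_rowid"] not in seen:
            seen.add(row["_rowid"])
            merged.append(row)
    return merged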
pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
return Err(Error::Schema {
message: format!(
"rerank_hybrid must return a RecordBatch with a column named {}",
RELEVANCE_SCORE
),
});
}
Ok(())
}

View File

@@ -0,0 +1,223 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::BTreeMap;
use std::sync::Arc;
use arrow::{
array::downcast_array,
compute::{sort_to_indices, take},
};
use arrow_array::{Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use async_trait::async_trait;
use lance::dataset::ROW_ID;
use crate::error::{Error, Result};
use crate::rerankers::{Reranker, RELEVANCE_SCORE};
/// Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm,
/// based on the ranks of the vector and FTS search results.
///
#[derive(Debug)]
pub struct RRFReranker {
k: f32,
}
impl RRFReranker {
/// Create a new RRFReranker
///
/// The parameter k is a constant used in the RRF formula (default is 60).
/// Experiments indicate that k = 60 is near-optimal, but that the choice
/// is not critical. See the paper:
/// https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
pub fn new(k: f32) -> Self {
Self { k }
}
}
impl Default for RRFReranker {
fn default() -> Self {
Self { k: 60.0 }
}
}
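The scoring inside rerank_hybrid, reduced to pure Python: each hit contributes 1 / (position + k), summed across both result lists, then sorted descending:

def rrf_scores(vector_ids, fts_ids, k=60.0):
    scores = {}
    for ids in (vector_ids, fts_ids):
        for i, row_id in enumerate(ids):
            scores[row_id] = scores.get(row_id, 0.0) + 1.0 / (i + k)
    return dict(sorted(scores.items(), key=lambda kv: -kv[1]))

# With k=1 and the ids from the test below:
# rrf_scores([1, 4, 2, 5, 3], [4, 5, 3], k=1.0)
#   -> {4: 1.5, 1: 1.0, 5: 0.75, 3: 0.533..., 2: 0.333...}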
#[async_trait]
impl Reranker for RRFReranker {
async fn rerank_hybrid(
&self,
_query: &str,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> Result<RecordBatch> {
let vector_ids = vector_results
.column_by_name(ROW_ID)
.ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in vector_results. found columns {:?}",
ROW_ID,
vector_results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>()
),
})?;
let fts_ids = fts_results
.column_by_name(ROW_ID)
.ok_or(Error::InvalidInput {
message: format!(
"expected column {} not found in fts_results. found columns {:?}",
ROW_ID,
fts_results
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<_>>()
),
})?;
let vector_ids: UInt64Array = downcast_array(&vector_ids);
let fts_ids: UInt64Array = downcast_array(&fts_ids);
let mut rrf_score_map = BTreeMap::new();
let mut update_score_map = |(i, result_id)| {
let score = 1.0 / (i as f32 + self.k);
rrf_score_map
.entry(result_id)
.and_modify(|e| *e += score)
.or_insert(score);
};
vector_ids
.values()
.iter()
.enumerate()
.for_each(&mut update_score_map);
fts_ids
.values()
.iter()
.enumerate()
.for_each(&mut update_score_map);
let combined_results = self.merge_results(vector_results, fts_results)?;
let combined_row_ids: UInt64Array =
downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
let relevance_scores = Float32Array::from_iter_values(
combined_row_ids
.values()
.iter()
.map(|row_id| rrf_score_map.get(row_id).unwrap())
.copied(),
);
// keep track of indices sorted by the relevance column
let sort_indices = sort_to_indices(
&relevance_scores,
Some(SortOptions {
descending: true,
..Default::default()
}),
None,
)
.unwrap();
// add relevance scores to columns
let mut columns = combined_results.columns().to_vec();
columns.push(Arc::new(relevance_scores));
// sort by the relevance scores
let columns = columns
.iter()
.map(|c| take(c, &sort_indices, None).unwrap())
.collect();
// add relevance score to schema
let mut fields = combined_results.schema().fields().to_vec();
fields.push(Arc::new(Field::new(
RELEVANCE_SCORE,
DataType::Float32,
false,
)));
let schema = Schema::new(fields);
let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;
Ok(combined_results)
}
}
#[cfg(test)]
pub mod test {
use super::*;
use arrow_array::StringArray;
#[tokio::test]
async fn test_rrf_reranker() {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new(ROW_ID, DataType::UInt64, false),
]));
let vec_results = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"])),
Arc::new(UInt64Array::from(vec![1, 4, 2, 5, 3])),
],
)
.unwrap();
let fts_results = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["bar", "bean", "dog"])),
Arc::new(UInt64Array::from(vec![4, 5, 3])),
],
)
.unwrap();
// scores should be calculated as:
// - foo = 1/1 = 1.0
// - bar = 1/2 + 1/1 = 1.5
// - baz = 1/3 = 0.333
// - bean = 1/4 + 1/2 = 0.75
// - dog = 1/5 + 1/3 = 0.533
// then we should get the result ranked in descending order
let reranker = RRFReranker::new(1.0);
let result = reranker
.rerank_hybrid("", vec_results, fts_results)
.await
.unwrap();
assert_eq!(3, result.schema().fields().len());
assert_eq!("name", result.schema().fields().first().unwrap().name());
assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name());
assert_eq!(
RELEVANCE_SCORE,
result.schema().fields().get(2).unwrap().name()
);
let names: StringArray = downcast_array(result.column(0));
assert_eq!(
names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec!["bar", "foo", "bean", "dog", "baz"]
);
let ids: UInt64Array = downcast_array(result.column(1));
assert_eq!(
ids.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![4, 1, 5, 3, 2]
);
let scores: Float32Array = downcast_array(result.column(2));
assert_eq!(
scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0]
);
}
}

View File

@@ -18,9 +18,9 @@ use std::path::Path;
use std::sync::Arc;
use arrow::array::AsArray;
use arrow::datatypes::Float32Type;
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{Field, Schema, SchemaRef};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use datafusion_physical_plan::projection::ProjectionExec;
@@ -58,8 +58,8 @@ use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, M
use crate::error::{Error, Result};
use crate::index::scalar::FtsIndexBuilder;
use crate::index::vector::{
suggested_num_partitions_for_hnsw, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder,
IvfPqIndexBuilder, VectorIndex,
suggested_num_partitions_for_hnsw, IvfFlatIndexBuilder, IvfHnswPqIndexBuilder,
IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
};
use crate::index::IndexStatistics;
use crate::index::{
@@ -1306,6 +1306,44 @@ impl NativeTable {
.collect())
}
async fn create_ivf_flat_index(
&self,
index: IvfFlatIndexBuilder,
field: &Field,
replace: bool,
) -> Result<()> {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF Flat index cannot be created on the column `{}` which has data type {}",
field.name(),
field.data_type()
),
});
}
let num_partitions = if let Some(n) = index.num_partitions {
n
} else {
suggested_num_partitions(self.count_rows(None).await?)
};
let mut dataset = self.dataset.get_mut().await?;
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_flat(
num_partitions as usize,
index.distance_type.into(),
);
dataset
.create_index(
&[field.name()],
IndexType::Vector,
None,
&lance_idx_params,
replace,
)
.await?;
Ok(())
}
async fn create_ivf_pq_index(
&self,
index: IvfPqIndexBuilder,
@@ -1778,6 +1816,10 @@ impl TableInternal for NativeTable {
Index::Bitmap(_) => self.create_bitmap_index(field, opts).await,
Index::LabelList(_) => self.create_label_list_index(field, opts).await,
Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await,
Index::IvfFlat(ivf_flat) => {
self.create_ivf_flat_index(ivf_flat, field, opts.replace)
.await
}
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
Index::IvfHnswPq(ivf_hnsw_pq) => {
self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
@@ -1848,14 +1890,21 @@ impl TableInternal for NativeTable {
message: format!("Column {} not found in dataset schema", column),
})?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
let mut is_binary = false;
if let arrow_schema::DataType::FixedSizeList(element, dim) = field.data_type() {
match element.data_type() {
e_type if e_type.is_floating() => {}
e_type if *e_type == DataType::UInt8 => {
is_binary = true;
}
_ => {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
@@ -1870,12 +1919,22 @@ impl TableInternal for NativeTable {
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
if is_binary {
let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
let query_vector = query_vector.as_primitive::<UInt8Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
}
}
scanner.limit(
query.base.limit.map(|limit| limit as i64),

View File

@@ -110,7 +110,7 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
.iter()
.filter_map(|field| match field.data_type() {
arrow_schema::DataType::FixedSizeList(f, d)
if f.data_type().is_floating()
if (f.data_type().is_floating() || f.data_type() == &DataType::UInt8)
&& dim.map(|expect| *d == expect).unwrap_or(true) =>
{
Some(field.name())
@@ -171,7 +171,9 @@ pub fn supported_fts_data_type(dtype: &DataType) -> bool {
pub fn supported_vector_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
DataType::FixedSizeList(inner, _) => {
DataType::is_floating(inner.data_type()) || *inner.data_type() == DataType::UInt8
}
_ => false,
}
}