Compare commits


38 Commits

Author SHA1 Message Date
BubbleCal
f69b673c1e Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep 2024-11-11 17:36:06 +08:00
Umut Hope YILDIRIM
729718cb09 fix: arm64 runner proto already installed bug (#1810)
https://github.com/lancedb/lancedb/actions/runs/11748512661/job/32732745458
2024-11-08 14:49:37 -08:00
Umut Hope YILDIRIM
b1c84e0bda feat: added lancedb and vectordb release ci for win32-arm64-msvc npmjs only (#1805) 2024-11-08 11:40:57 -08:00
fzowl
cbbc07d0f5 feat: voyageai support (#1799)
Adding VoyageAI embedding and rerank support
2024-11-09 00:51:20 +05:30
Kursat Aktas
21021f94ca docs: introducing LanceDB Guru on Gurubase.io (#1797)
Hello team,

I'm the maintainer of [Anteon](https://github.com/getanteon/anteon). We
have created Gurubase.io with the mission of building a centralized,
open-source tool-focused knowledge base. Essentially, each "guru" is
equipped with custom knowledge to answer user questions based on
collected data related to that tool.

I wanted to update you that I've manually added the [LanceDB
Guru](https://gurubase.io/g/lancedb) to Gurubase. LanceDB Guru uses the
data from this repo and data from the
[docs](https://lancedb.github.io/lancedb/) to answer questions by
leveraging the LLM.

In this PR, I showcased the "LanceDB Guru", which highlights that
LanceDB now has an AI assistant available to help users with their
questions. Please let me know your thoughts on this contribution.

Additionally, if you want me to disable LanceDB Guru in Gurubase, just let me know; that's totally fine.

Signed-off-by: Kursat Aktas <kursat.ce@gmail.com>
2024-11-08 10:55:22 -08:00
BubbleCal
0ed77fa990 chore: impl Debug & Clone for Index params (#1808)
We don't really need these traits in lancedb, but all fields in `Index`
implement the two traits, so derive them to make it possible to use
`Index` elsewhere.

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-11-09 01:07:43 +08:00
BubbleCal
4372c231cd feat: support optimize indices in sync API (#1769)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-11-08 08:48:07 -08:00
BubbleCal
4c6b728a31 feat: support FTS options on RemoteTable
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-11-08 18:49:13 +08:00
BubbleCal
138a12a427 Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep 2024-11-08 18:49:09 +08:00
Umut Hope YILDIRIM
fa9ca8f7a6 ci: arm64 windows build support (#1770)
Adds support for 'aarch64-pc-windows-msvc'.
2024-11-06 15:34:23 -08:00
Lance Release
2a35d24ee6 Updating package-lock.json 2024-11-06 17:26:36 +00:00
Lance Release
dd9ce337e2 Bump version: 0.13.0-beta.0 → 0.13.0-beta.1 2024-11-06 17:26:17 +00:00
Will Jones
b9921d56cc fix(node): update default log level to warn (#1801)
🤦
2024-11-06 09:13:53 -08:00
Lance Release
0cfd9ed18e Updating package-lock.json 2024-11-05 23:21:50 +00:00
Lance Release
975398c3a8 Bump version: 0.12.0 → 0.13.0-beta.0 2024-11-05 23:21:32 +00:00
Lance Release
08d5f93f34 Bump version: 0.15.0 → 0.16.0-beta.0 2024-11-05 23:21:13 +00:00
Will Jones
91cab3b556 feat(python): transition Python remote sdk to use Rust implementation (#1701)
* Replaces Python implementation of Remote SDK with Rust one.
* Drops dependency on `attrs` and `cachetools`. Makes `requests` an
optional dependency used only for embeddings feature.
* Adds dependency on `nest-asyncio`. This was required to get hybrid
search working.
* Deprecate `request_thread_pool` parameter. We now use the tokio
threadpool.
* Stop caching the `schema` on a remote table. Schema is mutable and
there's no mechanism in place to invalidate the cache.
* Removed the client-side resolution of the vector column. We should
already be resolving this server-side.
2024-11-05 13:44:39 -08:00
Will Jones
c61bfc3af8 chore: update package locks (#1798) 2024-11-05 13:28:59 -08:00
Bert
4e8c7b0adf fix: serialize vectordb client errors as json (#1795) 2024-11-05 14:16:25 -05:00
Weston Pace
26f4a80e10 feat: upgrade to lance 0.19.2-beta.3 (#1794) 2024-11-05 06:43:41 -08:00
Will Jones
3604d20ad3 feat(python,node): support with_row_id in Python and remote (#1784)
Needed to support hybrid search in Remote SDK.
2024-11-04 11:25:45 -08:00
Gagan Bhullar
9708d829a9 fix: explain plan options (#1776)
PR fixes #1768
2024-11-04 10:25:34 -08:00
Will Jones
059c9794b5 fix(rust): fix update, open_table, fts search in remote client (#1785)
* `open_table` uses `POST` not `GET`
* `update` uses `predicate` key not `only_if`
* For FTS search, vector cannot be omitted. It must be passed as empty.
* Added logging of JSON request bodies to debug level logging.
2024-11-04 08:27:55 -08:00
Will Jones
15ed7f75a0 feat(python): support post filter on FTS (#1783) 2024-11-01 10:05:05 -07:00
Will Jones
96181ab421 feat: fast_search in Python and Node (#1623)
Sometimes it is acceptable for users to search only indexed data and skip
any new un-indexed data, for example when the un-indexed data will shortly
be indexed and they don't mind the delay. In these cases, we can save a lot
of CPU time during search and provide better latency. Users can activate
this on queries using `fast_search()`.
2024-11-01 09:29:09 -07:00
BubbleCal
0c108407ab bump version
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-11-01 15:17:25 +08:00
BubbleCal
a7fead3801 Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep 2024-11-01 15:15:23 +08:00
Will Jones
f3fc339ef6 fix(rust): fix delete, update, query in remote SDK (#1782)
Fixes several minor issues with Rust remote SDK:

* Delete uses `predicate` not `filter` as parameter
* Update does not return the row value in remote SDK
* Update takes tuples
* Content type returned by query node is wrong, so we shouldn't validate
it. https://github.com/lancedb/sophon/issues/2742
* Data returned by query endpoint is actually an Arrow IPC file, not IPC
stream.
2024-10-31 15:22:09 -07:00
Will Jones
113cd6995b fix: index_stats works for FTS indices (#1780)
When running `index_stats()` for an FTS index, users would get the
deserialization error:

```
InvalidInput { message: "error deserializing index statistics: unknown variant `Inverted`, expected one of `IvfPq`, `IvfHnswPq`, `IvfHnswSq`, `BTree`, `Bitmap`, `LabelList`, `FTS` at line 1 column 24" }
```
2024-10-30 11:33:49 -07:00
Lance Release
02535bdc88 Updating package-lock.json 2024-10-29 22:16:51 +00:00
Lance Release
facc7d61c0 Bump version: 0.12.0-beta.0 → 0.12.0 2024-10-29 22:16:32 +00:00
Lance Release
f947259f16 Bump version: 0.11.1-beta.1 → 0.12.0-beta.0 2024-10-29 22:16:27 +00:00
BubbleCal
50c68feae9 Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep 2024-09-29 15:05:56 +08:00
BubbleCal
f30c5b24fa fix
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-09-27 17:58:35 +08:00
BubbleCal
2a477ad387 Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep 2024-09-27 17:00:31 +08:00
BubbleCal
0b29aca23b Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep 2024-09-09 08:09:13 +08:00
BubbleCal
df62c3d9ac Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep 2024-09-04 16:45:46 +08:00
BubbleCal
aef4656053 feat: use relative lance
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-08-13 16:24:34 +08:00
73 changed files with 1969 additions and 1126 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.11.1-beta.1"
current_version = "0.13.0-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -92,6 +92,11 @@ glob = "node/package.json"
replace = "\"@lancedb/vectordb-win32-x64-msvc\": \"{new_version}\""
search = "\"@lancedb/vectordb-win32-x64-msvc\": \"{current_version}\""
[[tool.bumpversion.files]]
glob = "node/package.json"
replace = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{new_version}\""
search = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{current_version}\""
# Cargo files
# ------------
[[tool.bumpversion.files]]
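The bumpversion `parse` pattern is truncated in the diff above; this sketch reproduces just the visible major/minor/patch portion, plus an assumed optional pre-release suffix (not shown in the diff), to illustrate how a version like "0.13.0-beta.1" is decomposed:

```python
import re

pattern = re.compile(
    r"(?P<major>0|[1-9]\d*)\."
    r"(?P<minor>0|[1-9]\d*)\."
    r"(?P<patch>0|[1-9]\d*)"
    r"(?:-(?P<prerelease>[0-9A-Za-z.-]+))?"  # assumed suffix, not in the visible diff
)

m = pattern.match("0.13.0-beta.1")
print(m.group("major"), m.group("minor"), m.group("patch"), m.group("prerelease"))
# → 0 13 0 beta.1
```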

View File

@@ -38,3 +38,7 @@ rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm
# not found errors on systems that are missing it.
[target.x86_64-pc-windows-msvc]
rustflags = ["-Ctarget-feature=+crt-static"]
# Experimental target for Arm64 Windows
[target.aarch64-pc-windows-msvc]
rustflags = ["-Ctarget-feature=+crt-static"]

View File

@@ -226,6 +226,126 @@ jobs:
path: |
node/dist/lancedb-vectordb-win32*.tgz
node-windows-arm64:
name: vectordb win32-arm64-msvc
runs-on: windows-4x-arm
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
- name: Cache installations
id: cache-installs
uses: actions/cache@v4
with:
path: |
C:\Program Files\Git
C:\BuildTools
C:\Program Files (x86)\Windows Kits
C:\Program Files\7-Zip
C:\protoc
key: ${{ runner.os }}-arm64-installs-v1
restore-keys: |
${{ runner.os }}-arm64-installs-
- name: Install Git
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
shell: powershell
- name: Add Git to PATH
run: |
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
shell: powershell
- name: Configure Git symlinks
run: git config --global core.symlinks true
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.13"
- name: Install Visual Studio Build Tools
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
"--installPath", "C:\BuildTools", `
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
shell: powershell
- name: Add Visual Studio Build Tools to PATH
run: |
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
# Add MSVC runtime libraries to LIB
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
# Add INCLUDE paths
$env:INCLUDE = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\include;" +
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\ucrt;" +
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\um;" +
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\shared"
Add-Content $env:GITHUB_ENV "INCLUDE=$env:INCLUDE"
shell: powershell
- name: Install Rust
run: |
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
shell: powershell
- name: Add Rust to PATH
run: |
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
shell: powershell
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install 7-Zip ARM
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
New-Item -Path 'C:\7zip' -ItemType Directory
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
shell: powershell
- name: Add 7-Zip to PATH
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
shell: powershell
- name: Install Protoc v21.12
if: steps.cache-installs.outputs.cache-hit != 'true'
working-directory: C:\
run: |
if (Test-Path 'C:\protoc') {
Write-Host "Protoc directory exists, skipping installation"
return
}
New-Item -Path 'C:\protoc' -ItemType Directory
Set-Location C:\protoc
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
shell: powershell
- name: Add Protoc to PATH
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
shell: powershell
- name: Build Windows native node modules
run: .\ci\build_windows_artifacts.ps1 aarch64-pc-windows-msvc
- name: Upload Windows ARM64 Artifacts
uses: actions/upload-artifact@v4
with:
name: node-native-windows-arm64
path: |
node/dist/*.node
nodejs-windows:
name: lancedb ${{ matrix.target }}
runs-on: windows-2022
@@ -260,9 +380,119 @@ jobs:
path: |
nodejs/dist/*.node
nodejs-windows-arm64:
name: lancedb win32-arm64-msvc
runs-on: windows-4x-arm
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
- name: Cache installations
id: cache-installs
uses: actions/cache@v4
with:
path: |
C:\Program Files\Git
C:\BuildTools
C:\Program Files (x86)\Windows Kits
C:\Program Files\7-Zip
C:\protoc
key: ${{ runner.os }}-arm64-installs-v1
restore-keys: |
${{ runner.os }}-arm64-installs-
- name: Install Git
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
shell: powershell
- name: Add Git to PATH
run: |
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
shell: powershell
- name: Configure Git symlinks
run: git config --global core.symlinks true
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.13"
- name: Install Visual Studio Build Tools
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
"--installPath", "C:\BuildTools", `
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
shell: powershell
- name: Add Visual Studio Build Tools to PATH
run: |
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
$env:LIB = ""
Add-Content $env:GITHUB_ENV "LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
shell: powershell
- name: Install Rust
run: |
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
shell: powershell
- name: Add Rust to PATH
run: |
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
shell: powershell
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install 7-Zip ARM
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
New-Item -Path 'C:\7zip' -ItemType Directory
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
shell: powershell
- name: Add 7-Zip to PATH
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
shell: powershell
- name: Install Protoc v21.12
if: steps.cache-installs.outputs.cache-hit != 'true'
working-directory: C:\
run: |
if (Test-Path 'C:\protoc') {
Write-Host "Protoc directory exists, skipping installation"
return
}
New-Item -Path 'C:\protoc' -ItemType Directory
Set-Location C:\protoc
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
shell: powershell
- name: Add Protoc to PATH
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
shell: powershell
- name: Build Windows native node modules
run: .\ci\build_windows_artifacts_nodejs.ps1 aarch64-pc-windows-msvc
- name: Upload Windows ARM64 Artifacts
uses: actions/upload-artifact@v4
with:
name: nodejs-native-windows-arm64
path: |
nodejs/dist/*.node
release:
name: vectordb NPM Publish
needs: [node, node-macos, node-linux, node-windows]
needs: [node, node-macos, node-linux, node-windows, node-windows-arm64]
runs-on: ubuntu-latest
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
@@ -302,7 +532,7 @@ jobs:
release-nodejs:
name: lancedb NPM Publish
needs: [nodejs-macos, nodejs-linux, nodejs-windows]
needs: [nodejs-macos, nodejs-linux, nodejs-windows, nodejs-windows-arm64]
runs-on: ubuntu-latest
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')

View File

@@ -35,21 +35,21 @@ jobs:
CC: clang-18
CXX: clang++-18
steps:
- uses: actions/checkout@v4
with:
- uses: actions/checkout@v4
with:
fetch-depth: 0
lfs: true
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install dependencies
run: |
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Run format
run: cargo fmt --all -- --check
- name: Run clippy
run: cargo clippy --workspace --tests --all-features -- -D warnings
- name: Run format
run: cargo fmt --all -- --check
- name: Run clippy
run: cargo clippy --workspace --tests --all-features -- -D warnings
linux:
timeout-minutes: 30
# To build all features, we need more disk space than is available
@@ -65,37 +65,37 @@ jobs:
CC: clang-18
CXX: clang++-18
steps:
- uses: actions/checkout@v4
with:
- uses: actions/checkout@v4
with:
fetch-depth: 0
lfs: true
- uses: Swatinem/rust-cache@v2
with:
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install dependencies
run: |
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Make Swap
run: |
sudo fallocate -l 16G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
- name: Start S3 integration test environment
working-directory: .
run: docker compose up --detach --wait
- name: Build
run: cargo build --all-features
- name: Run tests
run: cargo test --all-features
- name: Run examples
run: cargo run --example simple
- name: Make Swap
run: |
sudo fallocate -l 16G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
- name: Start S3 integration test environment
working-directory: .
run: docker compose up --detach --wait
- name: Build
run: cargo build --all-features
- name: Run tests
run: cargo test --all-features
- name: Run examples
run: cargo run --example simple
macos:
timeout-minutes: 30
strategy:
matrix:
mac-runner: [ "macos-13", "macos-14" ]
mac-runner: ["macos-13", "macos-14"]
runs-on: "${{ matrix.mac-runner }}"
defaults:
run:
@@ -104,8 +104,8 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
lfs: true
fetch-depth: 0
lfs: true
- name: CPU features
run: sysctl -a | grep cpu
- uses: Swatinem/rust-cache@v2
@@ -139,3 +139,116 @@ jobs:
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo build
cargo test
windows-arm64:
runs-on: windows-4x-arm
steps:
- name: Cache installations
id: cache-installs
uses: actions/cache@v4
with:
path: |
C:\Program Files\Git
C:\BuildTools
C:\Program Files (x86)\Windows Kits
C:\Program Files\7-Zip
C:\protoc
key: ${{ runner.os }}-arm64-installs-v1
restore-keys: |
${{ runner.os }}-arm64-installs-
- name: Install Git
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
shell: powershell
- name: Add Git to PATH
run: |
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
shell: powershell
- name: Configure Git symlinks
run: git config --global core.symlinks true
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.13"
- name: Install Visual Studio Build Tools
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
"--installPath", "C:\BuildTools", `
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
shell: powershell
- name: Add Visual Studio Build Tools to PATH
run: |
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
# Add MSVC runtime libraries to LIB
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
# Add INCLUDE paths
$env:INCLUDE = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\include;" +
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\ucrt;" +
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\um;" +
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\shared"
Add-Content $env:GITHUB_ENV "INCLUDE=$env:INCLUDE"
shell: powershell
- name: Install Rust
run: |
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
shell: powershell
- name: Add Rust to PATH
run: |
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
shell: powershell
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install 7-Zip ARM
if: steps.cache-installs.outputs.cache-hit != 'true'
run: |
New-Item -Path 'C:\7zip' -ItemType Directory
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
shell: powershell
- name: Add 7-Zip to PATH
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
shell: powershell
- name: Install Protoc v21.12
if: steps.cache-installs.outputs.cache-hit != 'true'
working-directory: C:\
run: |
if (Test-Path 'C:\protoc') {
Write-Host "Protoc directory exists, skipping installation"
return
}
New-Item -Path 'C:\protoc' -ItemType Directory
Set-Location C:\protoc
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
shell: powershell
- name: Add Protoc to PATH
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
shell: powershell
- name: Run tests
run: |
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo build --target aarch64-pc-windows-msvc
cargo test --target aarch64-pc-windows-msvc

View File

@@ -21,13 +21,13 @@ categories = ["database-implementations"]
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
[workspace.dependencies]
lance = { "version" = "=0.19.1", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.19.1" }
lance-linalg = { "version" = "=0.19.1" }
lance-table = { "version" = "=0.19.1" }
lance-testing = { "version" = "=0.19.1" }
lance-datafusion = { "version" = "=0.19.1" }
lance-encoding = { "version" = "=0.19.1" }
lance = { "version" = "=0.19.2", "features" = ["dynamodb"], path = "../lance/rust/lance"}
lance-index = { "version" = "=0.19.2", path = "../lance/rust/lance-index"}
lance-linalg = { "version" = "=0.19.2", path = "../lance/rust/lance-linalg"}
lance-testing = { "version" = "=0.19.2", path = "../lance/rust/lance-testing"}
lance-datafusion = { "version" = "=0.19.2", path = "../lance/rust/lance-datafusion"}
lance-encoding = { "version" = "=0.19.2", path = "../lance/rust/lance-encoding"}
lance-table = { "version" = "=0.19.2", path = "../lance/rust/lance-table"}
# Note that this one does not include pyarrow
arrow = { version = "52.2", optional = false }
arrow-array = "52.2"

View File

@@ -10,6 +10,7 @@
[![Blog](https://img.shields.io/badge/Blog-12100E?style=for-the-badge&logoColor=white)](https://blog.lancedb.com/)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/zMM32dvNtd)
[![Twitter](https://img.shields.io/badge/Twitter-%231DA1F2.svg?style=for-the-badge&logo=Twitter&logoColor=white)](https://twitter.com/lancedb)
[![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20LanceDB%20Guru-006BFF?style=for-the-badge)](https://gurubase.io/g/lancedb)
</p>

View File

@@ -3,6 +3,7 @@
# Targets supported:
# - x86_64-pc-windows-msvc
# - i686-pc-windows-msvc
# - aarch64-pc-windows-msvc
function Prebuild-Rust {
param (
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
$targets = $args[0]
if (-not $targets) {
$targets = "x86_64-pc-windows-msvc"
$targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
}
Write-Host "Building artifacts for targets: $targets"

View File

@@ -3,6 +3,7 @@
# Targets supported:
# - x86_64-pc-windows-msvc
# - i686-pc-windows-msvc
# - aarch64-pc-windows-msvc
function Prebuild-Rust {
param (
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
$targets = $args[0]
if (-not $targets) {
$targets = "x86_64-pc-windows-msvc"
$targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
}
Write-Host "Building artifacts for targets: $targets"

View File

@@ -0,0 +1,51 @@
# VoyageAI Embeddings
Voyage AI provides cutting-edge embeddings and rerankers.

Using the Voyage AI API requires the `voyageai` package, which can be installed with `pip install voyageai`. Voyage AI embeddings generate vector representations of text data, which can be used for tasks such as semantic search, clustering, and classification.

You also need to set the `VOYAGE_API_KEY` environment variable to use the Voyage AI API.

Supported models are:

- voyage-3
- voyage-3-lite
- voyage-finance-2
- voyage-multilingual-2
- voyage-law-2
- voyage-code-2

Supported parameters (passed to the `create` method) are:

| Parameter | Type | Default Value | Description |
|---|---|--------|---------|
| `name` | `str` | `"voyage-3"` | The model ID of the model to use. Supported base models for text embeddings: voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
| `input_type` | `str` | `None` | Type of the input text. Defaults to `None`. Other options: `query`, `document`. |
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
Usage Example:
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry

voyageai = (
    EmbeddingFunctionRegistry.get_instance()
    .get("voyageai")
    .create(name="voyage-3")
)

class TextModel(LanceModel):
    text: str = voyageai.SourceField()
    vector: Vector(voyageai.ndims()) = voyageai.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"},
]

db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```

View File

@@ -0,0 +1,77 @@
# Voyage AI Reranker
Voyage AI provides cutting-edge embeddings and rerankers.

This reranker uses the [VoyageAI](https://docs.voyageai.com/docs/) API to rerank search results. You can use it by passing `VoyageAIReranker()` to the `rerank()` method. Note that you'll need to either set the `VOYAGE_API_KEY` environment variable or pass the `api_key` argument to use this reranker.
!!! note
Supported Query Types: Hybrid, Vector, FTS
```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import VoyageAIReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"},
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)

reranker = VoyageAIReranker(model_name="rerank-2")

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```
Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `None` | The name of the reranker model to use. Available models: rerank-2, rerank-2-lite. |
| `column` | `str` | `"text"` | The name of the column to use as input to the reranker model. |
| `top_n` | `int` | `None` | The number of results to return. If `None`, all results are returned. |
| `api_key` | `str` | `None` | The API key for the Voyage AI API. If not provided, the `VOYAGE_API_KEY` environment variable is used. |
| `return_score` | `str` | `"relevance"` | The type of score to return: `"relevance"` or `"all"`. If `"relevance"`, only the `_relevance_score` is returned. If `"all"`, the vector and/or FTS scores are also returned, depending on the query type. |
| `truncation` | `bool` | `None` | Whether to truncate the input to satisfy the context length limit on the query and the documents. |
## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:
### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |
### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the vector score (`_distance`) along with the relevance score (`_relevance_score`) |
### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the FTS score (`score`) along with the relevance score (`_relevance_score`) |
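As a plain-Python illustration of the column behavior described in the tables above (this is a hedged sketch of the documented semantics, not LanceDB's implementation; `apply_return_score` is a hypothetical helper):

```python
# Illustrative sketch only: shows which score columns survive each
# `return_score` mode, per the tables above. Not LanceDB code.
RAW_SCORE_COLUMNS = {"_distance", "score"}  # raw vector and FTS scores


def apply_return_score(row: dict, return_score: str = "relevance") -> dict:
    if return_score == "relevance":
        # Keep only the reranker's `_relevance_score`; drop raw search scores.
        return {k: v for k, v in row.items() if k not in RAW_SCORE_COLUMNS}
    if return_score == "all":
        # Keep raw vector/FTS scores alongside `_relevance_score`.
        return dict(row)
    raise ValueError("return_score must be 'relevance' or 'all'")


row = {"text": "hello world", "_distance": 0.12, "_relevance_score": 0.93}
print(apply_return_score(row))         # {'text': 'hello world', '_relevance_score': 0.93}
print(apply_return_score(row, "all"))  # keeps '_distance' as well
```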

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.11.1-beta.1</version>
<version>0.13.0-beta.1</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.11.1-beta.1</version>
<version>0.13.0-beta.1</version>
<packaging>pom</packaging>
<name>LanceDB Parent</name>

node/package-lock.json (generated)
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"cpu": [
"x64",
"arm64"
@@ -52,11 +52,12 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1",
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1",
"@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
@@ -327,65 +328,60 @@
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.11.1-beta.1.tgz",
"integrity": "sha512-q9jcCbmcz45UHmjgecL6zK82WaqUJsARfniwXXPcnd8ooISVhPkgN+RVKv6edwI9T0PV+xVRYq+LQLlZu5fyxw==",
"version": "0.13.0-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.13.0-beta.1.tgz",
"integrity": "sha512-beOrf6selCzzhLgDG8Nibma4nO/CSnA1wUKRmlJHEPtGcg7PW18z6MP/nfwQMpMR/FLRfTo8pPTbpzss47MiQQ==",
"cpu": [
"arm64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.11.1-beta.1.tgz",
"integrity": "sha512-E5tCTS5TaTkssTPa+gdnFxZJ1f60jnSIJXhqufNFZk4s+IMViwR1BPqaqE++WY5c1uBI55ef1862CROKDKX4gg==",
"version": "0.13.0-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.13.0-beta.1.tgz",
"integrity": "sha512-YdraGRF/RbJRkKh0v3xT03LUhq47T2GtCvJ5gZp8wKlh4pHa8LuhLU0DIdvmG/DT5vuQA+td8HDkBm/e3EOdNg==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.11.1-beta.1.tgz",
"integrity": "sha512-Obohy6TH31Uq+fp6ZisHR7iAsvgVPqBExrycVcIJqrLZnIe88N9OWUwBXkmfMAw/2hNJFwD4tU7+4U2FcBWX4w==",
"version": "0.13.0-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.13.0-beta.1.tgz",
"integrity": "sha512-Pp0O/uhEqof1oLaWrNbv+Ym+q8kBkiCqaA5+2eAZ6a3e9U+Ozkvb0FQrHuyi9adJ5wKQ4NabyQE9BMf2bYpOnQ==",
"cpu": [
"arm64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.11.1-beta.1.tgz",
"integrity": "sha512-3Meu0dgrzNrnBVVQhxkUSAOhQNmgtKHvOvmrRLUicV+X19hd33udihgxVpZZb9mpXenJ8lZsS+Jq6R0hWqntag==",
"version": "0.13.0-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.13.0-beta.1.tgz",
"integrity": "sha512-y8nxOye4egfWF5FGED9EfkmZ1O5HnRLU4a61B8m5JSpkivO9v2epTcbYN0yt/7ZFCgtqMfJ8VW4Mi7qQcz3KDA==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.11.1-beta.1.tgz",
"integrity": "sha512-BafZ9OJPQXsS7JW0weAl12wC+827AiRjfUrE5tvrYWZah2OwCF2U2g6uJ3x4pxfwEGsv5xcHFqgxlS7ttFkh+Q==",
"version": "0.13.0-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.13.0-beta.1.tgz",
"integrity": "sha512-STMDP9dp0TBLkB3ro+16pKcGy6bmbhRuEZZZ1Tp5P75yTPeVh4zIgWkidMdU1qBbEYM7xacnsp9QAwgLnMU/Ow==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"win32"

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
@@ -84,14 +84,16 @@
"aarch64-apple-darwin": "@lancedb/vectordb-darwin-arm64",
"x86_64-unknown-linux-gnu": "@lancedb/vectordb-linux-x64-gnu",
"aarch64-unknown-linux-gnu": "@lancedb/vectordb-linux-arm64-gnu",
"x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc"
"x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc",
"aarch64-pc-windows-msvc": "@lancedb/vectordb-win32-arm64-msvc"
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1",
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1",
"@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1"
}
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import axios, { type AxiosResponse, type ResponseType } from 'axios'
import axios, { type AxiosError, type AxiosResponse, type ResponseType } from 'axios'
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
@@ -197,7 +197,7 @@ export class HttpLancedbClient {
response = await callWithMiddlewares(req, this._middlewares)
return response
} catch (err: any) {
console.error('error: ', err)
console.error(serializeErrorAsJson(err))
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
@@ -247,7 +247,8 @@ export class HttpLancedbClient {
// return response
} catch (err: any) {
console.error('error: ', err)
console.error(serializeErrorAsJson(err))
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
@@ -287,3 +288,15 @@ export class HttpLancedbClient {
return clone
}
}
function serializeErrorAsJson(err: AxiosError) {
const error = JSON.parse(JSON.stringify(err, Object.getOwnPropertyNames(err)))
error.response = err.response != null
? JSON.parse(JSON.stringify(
err.response,
// config contains the request data, too noisy
Object.getOwnPropertyNames(err.response).filter(prop => prop !== 'config')
))
: null
return JSON.stringify({ error })
}
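The `serializeErrorAsJson` helper above serializes an error's own properties while dropping the noisy `config` field from the response. For illustration, the same idea sketched in Python (all names here are hypothetical examples, not LanceDB code):

```python
import json

# Hedged sketch of the same idea as the TypeScript helper above:
# serialize an error's own attributes to JSON, dropping noisy fields
# (the TS version filters out the Axios response's `config`).
NOISY_FIELDS = {"config"}


def serialize_error_as_json(err: Exception) -> str:
    payload = {
        k: repr(v)
        for k, v in vars(err).items()
        if k not in NOISY_FIELDS
    }
    payload["type"] = type(err).__name__
    payload["message"] = str(err)
    return json.dumps({"error": payload})


class FakeHttpError(Exception):
    """Stand-in for an HTTP client error with a noisy `config` attribute."""

    def __init__(self, message, response, config):
        super().__init__(message)
        self.response = response
        self.config = config  # contains request data, too noisy to log


err = FakeHttpError("boom", response={"status": 500}, config={"headers": {}})
print(serialize_error_as_json(err))  # `config` is omitted from the output
```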

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.11.1-beta.1"
version = "0.13.0-beta.1"
license.workspace = true
description.workspace = true
repository.workspace = true
@@ -18,7 +18,7 @@ futures.workspace = true
lancedb = { path = "../rust/lancedb", features = ["remote"] }
napi = { version = "2.16.8", default-features = false, features = [
"napi9",
"async",
"async"
] }
napi-derive = "2.16.4"
# Prevent dynamic linking of lzma, which comes from datafusion

View File

@@ -402,6 +402,40 @@ describe("When creating an index", () => {
expect(rst.numRows).toBe(1);
});
it("should be able to query unindexed data", async () => {
await tbl.createIndex("vec");
await tbl.add([
{
id: 300,
vec: Array(32)
.fill(1)
.map(() => Math.random()),
tags: [],
},
]);
const plan1 = await tbl.query().nearestTo(queryVec).explainPlan(true);
expect(plan1).toMatch("LanceScan");
const plan2 = await tbl
.query()
.nearestTo(queryVec)
.fastSearch()
.explainPlan(true);
expect(plan2).not.toMatch("LanceScan");
});
it("should be able to query with row id", async () => {
const results = await tbl
.query()
.nearestTo(queryVec)
.withRowId()
.limit(1)
.toArray();
expect(results.length).toBe(1);
expect(results[0]).toHaveProperty("_rowid");
});
it("should allow parameters to be specified", async () => {
await tbl.createIndex("vec", {
config: Index.ivfPq({

View File

@@ -239,6 +239,29 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
return this;
}
/**
* Skip searching un-indexed data. This can make search faster, but will miss
* any data that is not yet indexed.
*
* Use {@link lancedb.Table#optimize} to index all un-indexed data.
*/
fastSearch(): this {
this.doCall((inner: NativeQueryType) => inner.fastSearch());
return this;
}
/**
* Whether to return the row id in the results.
*
* This column can be used to match results between different queries. For
* example, to match results from a full text search and a vector search in
* order to perform hybrid search.
*/
withRowId(): this {
this.doCall((inner: NativeQueryType) => inner.withRowId());
return this;
}
protected nativeExecute(
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -0,0 +1,3 @@
# `@lancedb/lancedb-win32-arm64-msvc`
This is the **aarch64-pc-windows-msvc** binary for `@lancedb/lancedb`

View File

@@ -0,0 +1,18 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.13.0-beta.1",
"os": [
"win32"
],
"cpu": [
"arm64"
],
"main": "lancedb.win32-arm64-msvc.node",
"files": [
"lancedb.win32-arm64-msvc.node"
],
"license": "Apache 2.0",
"engines": {
"node": ">= 18"
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.11.1-beta.1",
"version": "0.12.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.11.1-beta.1",
"version": "0.12.0",
"cpu": [
"x64",
"arm64"

View File

@@ -10,7 +10,7 @@
"vector database",
"ann"
],
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -82,7 +82,7 @@ pub struct OpenTableOptions {
#[napi::module_init]
fn init() {
let env = Env::new()
.filter_or("LANCEDB_LOG", "trace")
.filter_or("LANCEDB_LOG", "warn")
.write_style("LANCEDB_LOG_STYLE");
env_logger::init_from_env(env);
}

View File

@@ -80,6 +80,16 @@ impl Query {
Ok(VectorQuery { inner })
}
#[napi]
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
#[napi]
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)]
pub async fn execute(
&self,
@@ -183,6 +193,16 @@ impl VectorQuery {
self.inner = self.inner.clone().offset(offset as usize);
}
#[napi]
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
#[napi]
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)]
pub async fn execute(
&self,

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.15.0"
current_version = "0.16.0-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.15.0"
version = "0.16.0-beta.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -3,13 +3,11 @@ name = "lancedb"
# version in Cargo.toml
dependencies = [
"deprecation",
"pylance==0.19.1",
"requests>=2.31.0",
"nest-asyncio~=1.0",
"pylance==0.19.2-beta.3",
"tqdm>=4.27.0",
"pydantic>=1.10",
"attrs>=21.3.0",
"packaging",
"cachetools",
"overrides>=0.7",
]
description = "lancedb"
@@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = [
"requests>=2.31.0",
"openai>=1.6.1",
"sentence-transformers",
"torch",

View File

@@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any
__version__ = importlib.metadata.version("lancedb")
from lancedb.remote import ClientConfig
from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .remote import ClientConfig
from .schema import vector
from .table import AsyncTable
@@ -37,6 +35,7 @@ def connect(
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
**kwargs: Any,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -64,14 +63,10 @@ def connect(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
request_thread_pool: int or ThreadPoolExecutor, optional
The thread pool to use for making batch requests to the LanceDB Cloud API.
If an integer, then a ThreadPoolExecutor will be created with that
number of threads. If None, then a ThreadPoolExecutor will be created
with the default number of threads. If a ThreadPoolExecutor, then that
executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once).
client_config: ClientConfig or dict, optional
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
the keys are the attributes of the ClientConfig class. If None, then the
default configuration is used.
Examples
--------
@@ -94,6 +89,8 @@ def connect(
conn : DBConnection
A connection to a LanceDB database.
"""
from .remote.db import RemoteDBConnection
if isinstance(uri, str) and uri.startswith("db://"):
if api_key is None:
api_key = os.environ.get("LANCEDB_API_KEY")
@@ -106,7 +103,9 @@ def connect(
api_key,
region,
host_override,
# TODO: remove this (deprecation warning downstream)
request_thread_pool=request_thread_pool,
client_config=client_config,
**kwargs,
)

View File

@@ -36,6 +36,8 @@ class Connection(object):
data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None,
) -> Table: ...
async def rename_table(self, old_name: str, new_name: str) -> None: ...
async def drop_table(self, name: str) -> None: ...
class Table:
def name(self) -> str: ...

View File

@@ -817,6 +817,18 @@ class AsyncConnection(object):
table = await self._inner.open_table(name, storage_options, index_cache_size)
return AsyncTable(table)
async def rename_table(self, old_name: str, new_name: str):
"""Rename a table in the database.
Parameters
----------
old_name: str
The current name of the table.
new_name: str
The new name of the table.
"""
await self._inner.rename_table(old_name, new_name)
async def drop_table(self, name: str):
"""Drop a table from the database.

View File

@@ -27,3 +27,4 @@ from .imagebind import ImageBindEmbeddings
from .utils import with_embeddings
from .jinaai import JinaEmbeddings
from .watsonx import WatsonxEmbeddings
from .voyageai import VoyageAIEmbeddingFunction

View File

@@ -13,7 +13,6 @@
import os
import io
import requests
import base64
from urllib.parse import urlparse
from pathlib import Path
@@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction):
return [result["embedding"] for result in sorted_embeddings]
def _init_client(self):
import requests
if JinaEmbeddings._session is None:
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
api_key_not_found_help("jina")

View File

@@ -0,0 +1,127 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import ClassVar, List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT
@register("voyageai")
class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
"""
An embedding function that uses the VoyageAI API
https://docs.voyageai.com/docs/embeddings
Parameters
----------
name: str
The name of the model to use. List of acceptable models:
* voyage-3
* voyage-3-lite
* voyage-finance-2
* voyage-multilingual-2
* voyage-law-2
* voyage-code-2
Examples
--------
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
voyageai = EmbeddingFunctionRegistry
.get_instance()
.get("voyageai")
.create(name="voyage-3")
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
"""
name: str
client: ClassVar = None
def ndims(self):
if self.name == "voyage-3-lite":
return 512
elif self.name == "voyage-code-2":
return 1536
elif self.name in [
"voyage-3",
"voyage-finance-2",
"voyage-multilingual-2",
"voyage-law-2",
]:
return 1024
else:
raise ValueError(f"Model {self.name} not supported")
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, input_type="query")
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
input_type = (
kwargs.get("input_type") or "document"
) # assume source input type if not passed by `compute_query_embeddings`
return self.generate_embeddings(texts, input_type=input_type)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
input_type: Optional[str]
truncation: Optional[bool]
"""
VoyageAIEmbeddingFunction._init_client()
rs = VoyageAIEmbeddingFunction.client.embed(
texts=texts, model=self.name, **kwargs
)
return [emb for emb in rs.embeddings]
@staticmethod
def _init_client():
if VoyageAIEmbeddingFunction.client is None:
voyageai = attempt_import_or_raise("voyageai")
if os.environ.get("VOYAGE_API_KEY") is None:
api_key_not_found_help("voyageai")
VoyageAIEmbeddingFunction.client = voyageai.Client(
os.environ["VOYAGE_API_KEY"]
)

View File

@@ -110,7 +110,16 @@ class FTS:
remove_stop_words: bool = False,
ascii_folding: bool = False,
):
self._inner = LanceDbIndex.fts(with_position=with_position)
self._inner = LanceDbIndex.fts(
with_position=with_position,
base_tokenizer=base_tokenizer,
language=language,
max_token_length=max_token_length,
lower_case=lower_case,
stem=stem,
remove_stop_words=remove_stop_words,
ascii_folding=ascii_folding,
)
class HnswPq:
@@ -467,6 +476,8 @@ class IvfPq:
The default value is 256.
"""
if distance_type is not None:
distance_type = distance_type.lower()
self._inner = LanceDbIndex.ivf_pq(
distance_type=distance_type,
num_partitions=num_partitions,

View File

@@ -481,6 +481,7 @@ class LanceQueryBuilder(ABC):
>>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
GlobalLimitExec: skip=0, fetch=10
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
KNNVectorDistance: metric=l2
@@ -500,7 +501,16 @@ class LanceQueryBuilder(ABC):
nearest={
"column": self._vector_column,
"q": self._query,
"k": self._limit,
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
},
prefilter=self._prefilter,
filter=self._str_query,
limit=self._limit,
with_row_id=self._with_row_id,
offset=self._offset,
).explain_plan(verbose)
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
@@ -1315,6 +1325,48 @@ class AsyncQueryBase(object):
self._inner.offset(offset)
return self
def fast_search(self) -> AsyncQuery:
"""
Skip searching un-indexed data.
This can make queries faster, but will miss any data that has not been
indexed.
!!! tip
You can add new data into an existing index by calling
[AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
"""
self._inner.fast_search()
return self
def with_row_id(self) -> AsyncQuery:
"""
Include the _rowid column in the results.
"""
self._inner.with_row_id()
return self
def postfilter(self) -> AsyncQuery:
"""
If this is called then filtering will happen after the search instead of
before.
By default filtering will be performed before the search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:
@@ -1618,30 +1670,6 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.distance_type(distance_type)
return self
def postfilter(self) -> AsyncVectorQuery:
"""
If this is called then filtering will happen after the vector search instead of
before.
By default filtering will be performed before the vector search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the vector search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
def bypass_vector_index(self) -> AsyncVectorQuery:
"""
If this is called then any vector index is skipped

View File

@@ -11,62 +11,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import timedelta
from typing import List, Optional
import attrs
from lancedb import __version__
import pyarrow as pa
from pydantic import BaseModel
from lancedb.common import VECTOR_COLUMN_NAME
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
class VectorQuery(BaseModel):
# vector to search for
vector: List[float]
# sql filter to refine the query with
filter: Optional[str] = None
# top k results to return
k: int
# # metrics
_metric: str = "L2"
# which columns to return in the results
columns: Optional[List[str]] = None
# optional query parameters for tuning the results,
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
refine_factor: Optional[int] = None
vector_column: str = VECTOR_COLUMN_NAME
fast_search: bool = False
@attrs.define
class VectorQueryResult:
# for now the response is directly seralized into a pandas dataframe
tbl: pa.Table
def to_arrow(self) -> pa.Table:
return self.tbl
class LanceDBClient(abc.ABC):
@abc.abstractmethod
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query the LanceDB server for the given table and query."""
pass
__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"]
@dataclass
@@ -165,8 +116,8 @@ class RetryConfig:
@dataclass
class ClientConfig:
user_agent: str = f"LanceDB-Python-Client/{__version__}"
retry_config: Optional[RetryConfig] = None
timeout_config: Optional[TimeoutConfig] = None
retry_config: RetryConfig = field(default_factory=RetryConfig)
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
def __post_init__(self):
if isinstance(self.retry_config, dict):

View File

@@ -1,25 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Union
import pyarrow as pa
def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes:
"""Serialize a PyArrow Table to IPC binary."""
sink = pa.BufferOutputStream()
if isinstance(table, Iterable):
table = pa.Table.from_batches(table)
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
return sink.getvalue().to_pybytes()

View File

@@ -1,269 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin
import attrs
import pyarrow as pa
import requests
from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
from lancedb.remote.errors import LanceDBClientError
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
def _check_not_closed(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if self.closed:
raise ValueError("Connection is closed")
return f(self, *args, **kwargs)
return wrapped
def _read_ipc(resp: requests.Response) -> pa.Table:
resp_body = resp.content
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
return reader.read_all()
@attrs.define(slots=False)
class RestfulLanceDBClient:
db_name: str
region: str
api_key: Credential
host_override: Optional[str] = attrs.field(default=None)
closed: bool = attrs.field(default=False, init=False)
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True)
@functools.cached_property
def session(self) -> requests.Session:
sess = requests.Session()
retry_adapter_instance = retry_adapter(retry_adapter_options())
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
adapter_class = LanceDBClientHTTPAdapterFactory()
sess.mount("https://", adapter_class())
return sess
@property
def url(self) -> str:
return (
self.host_override
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
return False # Do not suppress exceptions
def close(self):
self.session.close()
self.closed = True
@functools.cached_property
def headers(self) -> Dict[str, str]:
headers = {
"x-api-key": self.api_key,
}
if self.region == "local": # Local test mode
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
if self.host_override:
headers["x-lancedb-database"] = self.db_name
return headers
@staticmethod
def _check_status(resp: requests.Response):
# Leaving request id empty for now, as we'll be replacing this impl
# with the Rust one shortly.
if resp.status_code == 404:
raise LanceDBClientError(
f"Not found: {resp.text}", request_id="", status_code=404
)
elif 400 <= resp.status_code < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif 500 <= resp.status_code < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif resp.status_code != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
@_check_not_closed
def get(self, uri: str, params: Optional[Union[Dict[str, Any], BaseModel]] = None):
"""Send a GET request and return the deserialized response payload."""
if isinstance(params, BaseModel):
params: Dict[str, Any] = params.dict(exclude_none=True)
with self.session.get(
urljoin(self.url, uri),
params=params,
headers=self.headers,
timeout=(self.connection_timeout, self.read_timeout),
) as resp:
self._check_status(resp)
return resp.json()
@_check_not_closed
def post(
self,
uri: str,
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
params: Optional[Dict[str, Any]] = None,
content_type: Optional[str] = None,
deserialize: Callable = lambda resp: resp.json(),
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Send a POST request and return the deserialized response payload.
Parameters
----------
uri : str
The uri to send the POST request to.
data: Union[Dict[str, Any], BaseModel]
request_id: Optional[str]
Optional client side request id to be sent in the request headers.
"""
if isinstance(data, BaseModel):
data: Dict[str, Any] = data.dict(exclude_none=True)
if isinstance(data, bytes):
req_kwargs = {"data": data}
else:
req_kwargs = {"json": data}
headers = self.headers.copy()
if content_type is not None:
headers["content-type"] = content_type
if request_id is not None:
headers["x-request-id"] = request_id
with self.session.post(
urljoin(self.url, uri),
headers=headers,
params=params,
timeout=(self.connection_timeout, self.read_timeout),
**req_kwargs,
) as resp:
self._check_status(resp)
return deserialize(resp)
@_check_not_closed
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
"""List all tables in the database."""
if page_token is None:
page_token = ""
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
return json["tables"]
@_check_not_closed
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table."""
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)
def mount_retry_adapter_for_table(self, table_name: str) -> None:
"""
Adds an http adapter to session that will retry retryable requests to the table.
"""
retry_options = retry_adapter_options(methods=["GET", "POST"])
retry_adapter_instance = retry_adapter(retry_options)
session = self.session
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
retry_adapter_instance,
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
retry_adapter_instance,
)
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
return {
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
"backoff_factor": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
),
"backoff_jitter": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
),
"statuses": [
int(i.strip())
for i in os.environ.get(
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
).split(",")
],
"methods": methods,
}
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
total_retries = options["retries"]
connect_retries = options["connect_retries"]
read_retries = options["read_retries"]
backoff_factor = options["backoff_factor"]
backoff_jitter = options["backoff_jitter"]
statuses = options["statuses"]
methods = frozenset(options["methods"])
logging.debug(
f"Setting up retry adapter with {total_retries} retries," # noqa G003
+ f"connect retries {connect_retries}, read retries {read_retries},"
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
+ f"methods {methods}"
)
return HTTPAdapter(
max_retries=Retry(
total=total_retries,
connect=connect_retries,
read=read_retries,
backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses,
allowed_methods=methods,
)
)
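The retry policy in `retry_adapter_options` is driven entirely by `LANCE_CLIENT_*` environment variables. The following sketch (same mapping as above, minus the `backoff_jitter` option, which requires urllib3 2.x) shows how those variables flow into a urllib3 `Retry` mounted on an `HTTPAdapter`:

```python
import os

from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# Override two of the defaults, as an operator would in their environment.
os.environ["LANCE_CLIENT_MAX_RETRIES"] = "5"
os.environ["LANCE_CLIENT_RETRY_STATUSES"] = "429, 503"

retry = Retry(
    total=int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
    backoff_factor=float(os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")),
    status_forcelist=[
        int(s.strip())
        for s in os.environ.get(
            "LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
        ).split(",")
    ],
    allowed_methods=frozenset(["GET"]),
)
adapter = HTTPAdapter(max_retries=retry)
```

The adapter would then be mounted on a session with `sess.mount(url_prefix, adapter)`, as `mount_retry_adapter_for_table` does per-table above.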

View File

@@ -1,115 +0,0 @@
# Copyright 2024 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This module contains an adapter that will close connections if they have not been
# used before a certain timeout. This is necessary because some load balancers will
# close connections after a certain amount of time, but the request module may not yet
# have received the FIN/ACK and will try to reuse the connection.
#
# TODO some of the code here can be simplified if/when this PR is merged:
# https://github.com/urllib3/urllib3/pull/3275
import datetime
import logging
import os
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPSConnection
from urllib3.connectionpool import HTTPSConnectionPool
from urllib3.poolmanager import PoolManager
def get_client_connection_timeout() -> int:
return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300"))
class LanceDBHTTPSConnection(HTTPSConnection):
"""
HTTPSConnection that tracks the last time it was used.
"""
idle_timeout: datetime.timedelta
last_activity: datetime.datetime
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_activity = datetime.datetime.now()
def request(self, *args, **kwargs):
self.last_activity = datetime.datetime.now()
super().request(*args, **kwargs)
def is_expired(self):
return datetime.datetime.now() - self.last_activity > self.idle_timeout
def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int):
"""
Creates a connection pool class that can be used to close idle connections.
"""
class LanceDBHTTPSConnectionPool(HTTPSConnectionPool):
# override the connection class
ConnectionCls = LanceDBHTTPSConnection
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _get_conn(self, timeout: float | None = None):
logging.debug("Getting https connection")
conn = super()._get_conn(timeout)
if conn.is_expired():
logging.debug("Closing expired connection")
conn.close()
return conn
def _new_conn(self):
conn = super()._new_conn()
conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout)
return conn
return LanceDBHTTPSConnectionPool
class LanceDBClientPoolManager(PoolManager):
def __init__(
self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs
):
super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs)
# inject our connection pool impl
connection_pool_class = LanceDBHTTPSConnectionPoolFactory(
client_idle_timeout=client_idle_timeout
)
self.pool_classes_by_scheme["https"] = connection_pool_class
def LanceDBClientHTTPAdapterFactory():
"""
Creates an HTTPAdapter class that can be used to close idle connections
"""
# closure over the timeout
client_idle_timeout = get_client_connection_timeout()
class LanceDBClientRequestHTTPAdapter(HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False):
# inject our pool manager impl
self.poolmanager = LanceDBClientPoolManager(
client_idle_timeout=client_idle_timeout,
num_pools=connections,
maxsize=maxsize,
block=block,
)
return LanceDBClientRequestHTTPAdapter
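The factory functions above all use the same closure pattern: configuration is captured when the class is created, because requests' adapter API gives no clean way to thread extra parameters through `init_poolmanager`. A minimal illustration of that pattern (a stand-in, not the real LanceDB adapter):

```python
import datetime

import requests
from requests.adapters import HTTPAdapter


def adapter_factory(idle_timeout_seconds: int):
    """Return an HTTPAdapter subclass closed over the idle timeout."""
    timeout = datetime.timedelta(seconds=idle_timeout_seconds)

    class IdleAwareAdapter(HTTPAdapter):
        # The class body captures `timeout` from the enclosing scope.
        captured_timeout = timeout

    return IdleAwareAdapter


session = requests.Session()
session.mount("https://", adapter_factory(300)())
```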

View File

@@ -11,13 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import timedelta
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse
import warnings
from cachetools import TTLCache
from lancedb import connect_async
from lancedb.remote import ClientConfig
import pyarrow as pa
from overrides import override
@@ -25,10 +28,8 @@ from ..common import DATA
from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, sanitize_create_table
from ..table import Table
from ..util import validate_table_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
class RemoteDBConnection(DBConnection):
@@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection):
region: str,
host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0,
read_timeout: float = 300.0,
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
connection_timeout: Optional[float] = None,
read_timeout: Optional[float] = None,
):
"""Connect to a remote LanceDB database."""
if isinstance(client_config, dict):
client_config = ClientConfig(**client_config)
elif client_config is None:
client_config = ClientConfig()
# These are legacy options from the old Python-based client. We keep them
# here for backwards compatibility, but will remove them in a future release.
if request_thread_pool is not None:
warnings.warn(
"request_thread_pool is no longer used and will be removed in "
"a future release.",
DeprecationWarning,
)
if connection_timeout is not None:
warnings.warn(
"connection_timeout is deprecated and will be removed in a future "
"release. Please use client_config.timeout_config.connect_timeout "
"instead.",
DeprecationWarning,
)
client_config.timeout_config.connect_timeout = timedelta(
seconds=connection_timeout
)
if read_timeout is not None:
warnings.warn(
"read_timeout is deprecated and will be removed in a future release. "
"Please use client_config.timeout_config.read_timeout instead.",
DeprecationWarning,
)
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self._uri = str(db_url)
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(
self.db_name,
region,
api_key,
host_override,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
import nest_asyncio
nest_asyncio.apply()
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.client_config = client_config
self._conn = self._loop.run_until_complete(
connect_async(
db_url,
api_key=api_key,
region=region,
host_override=host_override,
client_config=client_config,
)
)
self._request_thread_pool = request_thread_pool
self._table_cache = TTLCache(maxsize=10000, ttl=300)
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"
@@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection):
-------
An iterator of table names.
"""
while True:
result = self._client.list_tables(limit, page_token)
if len(result) > 0:
page_token = result[len(result) - 1]
else:
break
for item in result:
self._table_cache[item] = True
yield item
return self._loop.run_until_complete(
self._conn.table_names(start_after=page_token, limit=limit)
)
@override
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
@@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection):
"""
from .table import RemoteTable
self._client.mount_retry_adapter_for_table(name)
if index_cache_size is not None:
logging.info(
"index_cache_size is ignored in LanceDb Cloud"
" (there is no local cache to configure)"
)
# check if table exists
if self._table_cache.get(name) is None:
self._client.post(f"/v1/table/{name}/describe/")
self._table_cache[name] = True
return RemoteTable(self, name)
table = self._loop.run_until_complete(self._conn.open_table(name))
return RemoteTable(table, self.db_name, self._loop)
@override
def create_table(
@@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection):
"Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature."
)
if mode is not None:
logging.warning("mode is not yet supported on LanceDB Cloud.")
data, schema = sanitize_create_table(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
from .table import RemoteTable
data = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._client.post(
f"/v1/table/{name}/create/",
data=data,
request_id=request_id,
content_type=ARROW_STREAM_CONTENT_TYPE,
table = self._loop.run_until_complete(
self._conn.create_table(
name,
data,
mode=mode,
schema=schema,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
)
self._table_cache[name] = True
return RemoteTable(self, name)
return RemoteTable(table, self.db_name, self._loop)
@override
def drop_table(self, name: str):
@@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection):
name: str
The name of the table.
"""
self._client.post(
f"/v1/table/{name}/drop/",
)
self._table_cache.pop(name, default=None)
self._loop.run_until_complete(self._conn.drop_table(name))
@override
def rename_table(self, cur_name: str, new_name: str):
@@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection):
new_name: str
The new name of the table.
"""
self._client.post(
f"/v1/table/{cur_name}/rename/",
data={"new_table_name": new_name},
)
self._table_cache.pop(cur_name, default=None)
self._table_cache[new_name] = True
self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name))
async def close(self):
"""Close the connection to the database."""
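The deprecation handling above converts the legacy float `connection_timeout`/`read_timeout` arguments into timedeltas on `client_config.timeout_config`. A minimal stand-in (these dataclasses are not the real lancedb classes, just a sketch of the shape this diff assumes) showing that mapping:

```python
from dataclasses import dataclass, field
from datetime import timedelta


@dataclass
class TimeoutConfig:
    connect_timeout: timedelta = timedelta(seconds=120)
    read_timeout: timedelta = timedelta(seconds=300)


@dataclass
class ClientConfig:
    timeout_config: TimeoutConfig = field(default_factory=TimeoutConfig)


# A legacy call site passed floats; the new code converts them once,
# exactly as the deprecation branches above do.
legacy_connection_timeout = 120.0
config = ClientConfig()
config.timeout_config.connect_timeout = timedelta(seconds=legacy_connection_timeout)
```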

View File

@@ -11,53 +11,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
import asyncio
import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
import pyarrow as pa
from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import Query, Table, _sanitize_data
from ..util import value_to_sql, infer_vector_column_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection
from ..table import AsyncTable, Query, Table
class RemoteTable(Table):
def __init__(self, conn: RemoteDBConnection, name: str):
self._conn = conn
self.name = name
def __init__(
self,
table: AsyncTable,
db_name: str,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
self._loop = loop
self._table = table
self.db_name = db_name
@property
def name(self) -> str:
"""The name of the table"""
return self._table.name
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self.name})"
return f"RemoteTable({self.db_name}.{self.name})"
def __len__(self) -> int:
return self.count_rows(None)
@cached_property
@property
def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
of this Table
"""
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
schema = json_to_schema(resp["schema"])
return schema
return self._loop.run_until_complete(self._table.schema())
@property
def version(self) -> int:
"""Get the current version of the table"""
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
return resp["version"]
return self._loop.run_until_complete(self._table.version())
@cached_property
def embedding_functions(self) -> dict:
@@ -84,20 +88,18 @@ class RemoteTable(Table):
def list_indices(self):
"""List all the indices on the table"""
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
return resp
return self._loop.run_until_complete(self._table.list_indices())
def index_stats(self, index_uuid: str):
"""List all the stats of a specified index"""
resp = self._conn._client.post(
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
)
return resp
return self._loop.run_until_complete(self._table.index_stats(index_uuid))
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*,
replace: bool = False,
):
"""Creates a scalar index
Parameters
@@ -107,36 +109,51 @@ class RemoteTable(Table):
or string column.
index_type : str
The index type of the scalar index. Must be "scalar" (BTREE),
"BTREE", "BITMAP", or "LABEL_LIST"
"BTREE", "BITMAP", or "LABEL_LIST".
replace : bool
If True, replace the existing index with the new one.
"""
if index_type == "scalar" or index_type == "BTREE":
config = BTree()
elif index_type == "BITMAP":
config = Bitmap()
elif index_type == "LABEL_LIST":
config = LabelList()
else:
raise ValueError(f"Unknown index type: {index_type}")
data = {
"column": column,
"index_type": index_type,
"replace": True,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_scalar_index/", data=data
self._loop.run_until_complete(
self._table.create_index(column, config=config, replace=replace)
)
return resp
def create_fts_index(
self,
column: str,
*,
replace: bool = False,
with_position: bool = True,
# tokenizer configs:
base_tokenizer: str = "simple",
language: str = "English",
max_token_length: Optional[int] = 40,
lower_case: bool = True,
stem: bool = False,
remove_stop_words: bool = False,
ascii_folding: bool = False,
):
data = {
"column": column,
"index_type": "FTS",
"replace": replace,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_index/", data=data
config = FTS(
with_position=with_position,
base_tokenizer=base_tokenizer,
language=language,
max_token_length=max_token_length,
lower_case=lower_case,
stem=stem,
remove_stop_words=remove_stop_words,
ascii_folding=ascii_folding,
)
self._loop.run_until_complete(
self._table.create_index(column, config=config, replace=replace)
)
return resp
def create_index(
self,
@@ -204,17 +221,22 @@ class RemoteTable(Table):
"Existing indexes will always be replaced."
)
data = {
"column": vector_column_name,
"index_type": index_type,
"metric_type": metric,
"index_cache_size": index_cache_size,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_index/", data=data
)
index_type = index_type.upper()
if index_type == "VECTOR" or index_type == "IVF_PQ":
config = IvfPq(distance_type=metric)
elif index_type == "IVF_HNSW_PQ":
config = HnswPq(distance_type=metric)
elif index_type == "IVF_HNSW_SQ":
config = HnswSq(distance_type=metric)
else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
return resp
self._loop.run_until_complete(
self._table.create_index(vector_column_name, config=config)
)
def add(
self,
@@ -246,22 +268,10 @@ class RemoteTable(Table):
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
data, _ = _sanitize_data(
data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._conn._client.post(
f"/v1/table/{self.name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
self._loop.run_until_complete(
self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
)
def search(
@@ -337,12 +347,6 @@ class RemoteTable(Table):
# empty query builder is not supported in saas, raise error
if query is None and query_type != "hybrid":
raise ValueError("Empty query is not supported")
vector_column_name = infer_vector_column_name(
schema=self.schema,
query_type=query_type,
query=query,
vector_column_name=vector_column_name,
)
return LanceQueryBuilder.create(
self,
@@ -356,37 +360,9 @@ class RemoteTable(Table):
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
if (
query.vector is not None
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
if self._conn._request_thread_pool is None:
def submit(name, q):
f = Future()
f.set_result(self._conn._client.query(name, q))
return f
else:
def submit(name, q):
return self._conn._request_thread_pool.submit(
self._conn._client.query, name, q
)
results = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
results.append(submit(self.name, q))
return pa.concat_tables(
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
).to_reader()
else:
result = self._conn._client.query(self.name, query)
return result.to_arrow().to_reader()
return self._loop.run_until_complete(
self._table._execute_query(query, batch_size=batch_size)
)
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
@@ -403,42 +379,8 @@ class RemoteTable(Table):
on_bad_vectors: str,
fill_value: float,
):
data, _ = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
if merge._when_matched_update_all_condition is not None:
params["when_matched_update_all_filt"] = (
merge._when_matched_update_all_condition
)
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params["when_not_matched_by_source_delete_filt"] = (
merge._when_not_matched_by_source_condition
)
self._conn._client.post(
f"/v1/table/{self.name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
self._loop.run_until_complete(
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
def delete(self, predicate: str):
@@ -488,8 +430,7 @@ class RemoteTable(Table):
x vector _distance # doctest: +SKIP
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
"""
payload = {"predicate": predicate}
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
self._loop.run_until_complete(self._table.delete(predicate))
def update(
self,
@@ -539,18 +480,9 @@ class RemoteTable(Table):
2 2 [10.0, 10.0] # doctest: +SKIP
"""
if values is not None and values_sql is not None:
raise ValueError("Only one of values or values_sql can be provided")
if values is None and values_sql is None:
raise ValueError("Either values or values_sql must be provided")
if values is not None:
updates = [[k, value_to_sql(v)] for k, v in values.items()]
else:
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
self._loop.run_until_complete(
self._table.update(where=where, updates=values, updates_sql=values_sql)
)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
@@ -564,12 +496,21 @@ class RemoteTable(Table):
"compact_files() is not supported on the LanceDB cloud"
)
def count_rows(self, filter: Optional[str] = None) -> int:
payload = {"predicate": filter}
resp = self._conn._client.post(
f"/v1/table/{self.name}/count_rows/", data=payload
def optimize(
self,
*,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
"""optimize() is not supported on the LanceDB cloud.
Indices are optimized automatically."""
raise NotImplementedError(
"optimize() is not supported on the LanceDB cloud. "
"Indices are optimized automatically."
)
return resp
def count_rows(self, filter: Optional[str] = None) -> int:
return self._loop.run_until_complete(self._table.count_rows(filter))
def add_columns(self, transforms: Dict[str, str]):
raise NotImplementedError(

View File

@@ -7,6 +7,7 @@ from .openai import OpenaiReranker
from .jinaai import JinaReranker
from .rrf import RRFReranker
from .answerdotai import AnswerdotaiRerankers
from .voyageai import VoyageAIReranker
__all__ = [
"Reranker",
@@ -18,4 +19,5 @@ __all__ = [
"JinaReranker",
"RRFReranker",
"AnswerdotaiRerankers",
"VoyageAIReranker",
]

View File

@@ -12,7 +12,6 @@
# limitations under the License.
import os
import requests
from functools import cached_property
from typing import Union
@@ -57,6 +56,8 @@ class JinaReranker(Reranker):
@cached_property
def _client(self):
import requests
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
raise ValueError(
"JINA_API_KEY not set. Either set it in your environment or \

View File

@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from numpy import NaN
from numpy import nan
import pyarrow as pa
from .base import Reranker
@@ -71,7 +71,7 @@ class LinearCombinationReranker(Reranker):
elif self.score == "all":
results = results.append_column(
"_distance",
pa.array([NaN] * len(fts_results), type=pa.float32()),
pa.array([nan] * len(fts_results), type=pa.float32()),
)
return results
@@ -92,7 +92,7 @@ class LinearCombinationReranker(Reranker):
elif self.score == "all":
results = results.append_column(
"_score",
pa.array([NaN] * len(vector_results), type=pa.float32()),
pa.array([nan] * len(vector_results), type=pa.float32()),
)
return results
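The `NaN` to `nan` rename above is a NumPy 2.0 compatibility fix: the uppercase `np.NaN` alias was removed in NumPy 2.0, while lowercase `nan` exists in every release (and `float("nan")` is equivalent). The padding columns only need a float NaN value:

```python
import math

# float("nan") behaves the same as numpy's `nan` for building a padding
# column; uppercase `NaN` no longer exists in NumPy 2.0.
nan = float("nan")
padding = [nan] * 3
```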

View File

@@ -0,0 +1,133 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import cached_property
from typing import Union, Optional
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
class VoyageAIReranker(Reranker):
"""
Reranks the results using the VoyageAI Rerank API.
https://docs.voyageai.com/docs/reranker
Parameters
----------
model_name : str
The name of the cross encoder model to use. Available voyageai models are:
- rerank-2
- rerank-2-lite
column : str, default "text"
The name of the column to use as input to the cross encoder model.
top_n : int, default None
The number of results to return. If None, will return all results.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
api_key : str, default None
The API key to use. If None, will use the VOYAGE_API_KEY environment variable.
truncation : Optional[bool], default True
Whether to truncate inputs that exceed the model's context length.
"""
def __init__(
self,
model_name: str,
column: str = "text",
top_n: Optional[int] = None,
return_score="relevance",
api_key: Optional[str] = None,
truncation: Optional[bool] = True,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.top_n = top_n
self.api_key = api_key
self.truncation = truncation
@cached_property
def _client(self):
voyageai = attempt_import_or_raise("voyageai")
if os.environ.get("VOYAGE_API_KEY") is None and self.api_key is None:
raise ValueError(
"VOYAGE_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the VoyageAIReranker."
)
return voyageai.Client(
api_key=os.environ.get("VOYAGE_API_KEY") or self.api_key,
)
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
response = self._client.rerank(
query=query,
documents=docs,
top_k=self.top_n,
model=self.model_name,
truncation=self.truncation,
)
results = (
response.results
)  # list of results, each with (index, document, relevance_score) attributes, sorted by descending relevance_score
indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results])
) # tuples
result_set = result_set.take(list(indices))
# add the scores
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = self._keep_relevance_score(combined_results)
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for voyageai reranker"
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_score"])
return result_set

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import inspect
import time
from abc import ABC, abstractmethod
@@ -32,7 +33,7 @@ import pyarrow.fs as pa_fs
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .common import DATA, VEC, VECTOR_COLUMN_NAME, sanitize_uri
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
@@ -57,12 +58,14 @@ from .util import (
)
from .index import lang_mapping
from ._lancedb import connect as lancedb_connect
if TYPE_CHECKING:
import PIL
from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq
pd = safe_import_pandas()
pl = safe_import_polars()
@@ -893,6 +896,55 @@ class Table(ABC):
For most cases, the default should be fine.
"""
@abstractmethod
def optimize(
self,
*,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
"""
Optimize the on-disk data and indices for better performance.
Modeled after ``VACUUM`` in PostgreSQL.
Optimization covers three operations:
* Compaction: Merges small files into larger ones
* Prune: Removes old versions of the dataset
* Index: Optimizes the indices, adding new data to existing indices
Parameters
----------
cleanup_older_than: timedelta, optional, default 7 days
All files belonging to versions older than this will be removed. Set
to 0 days to remove all versions except the latest. The latest version
is never removed.
delete_unverified: bool, default False
Files leftover from a failed transaction may appear to be part of an
in-progress operation (e.g. appending new data) and these files will not
be deleted unless they are at least 7 days old. If delete_unverified is True
then these files will be deleted regardless of their age.
Experimental API
----------------
The optimization process is undergoing active development and may change.
Our goal with these changes is to improve the performance of optimization and
reduce the complexity.
That being said, it is essential today to run optimize if you want the best
performance. It should be stable and safe to use in production, but it is our
hope that the API may be simplified (or not even need to be called) in the
future.
How often an application should call optimize depends on the frequency of
data modifications. If data is frequently added, deleted, or updated then
optimize should be run frequently. A good rule of thumb is to run optimize if
you have added or modified 100,000 or more records or run more than 20 data
modification operations.
"""
@abstractmethod
def add_columns(self, transforms: Dict[str, str]):
"""
@@ -948,7 +1000,9 @@ class Table(ABC):
return _table_uri(self._conn.uri, self.name)
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
if get_uri_scheme(self._dataset_uri) != "file":
from .remote.table import RemoteTable
if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file":
return ("", None, False)
path = join_uri(self._dataset_uri, "_indices", "fts")
fs, path = fs_from_uri(path)
@@ -1969,6 +2023,83 @@ class LanceTable(Table):
"""
return self.to_lance().optimize.compact_files(*args, **kwargs)
def optimize(
self,
*,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
"""
Optimize the on-disk data and indices for better performance.
Modeled after ``VACUUM`` in PostgreSQL.
Optimization covers three operations:
* Compaction: Merges small files into larger ones
* Prune: Removes old versions of the dataset
* Index: Optimizes the indices, adding new data to existing indices
Parameters
----------
cleanup_older_than: timedelta, optional, default 7 days
All files belonging to versions older than this will be removed. Set
to 0 days to remove all versions except the latest. The latest version
is never removed.
delete_unverified: bool, default False
Files leftover from a failed transaction may appear to be part of an
in-progress operation (e.g. appending new data) and these files will not
be deleted unless they are at least 7 days old. If delete_unverified is True
then these files will be deleted regardless of their age.
Experimental API
----------------
The optimization process is undergoing active development and may change.
Our goal with these changes is to improve the performance of optimization and
reduce the complexity.
That being said, it is essential today to run optimize if you want the best
performance. It should be stable and safe to use in production, but it is our
hope that the API may be simplified (or not even need to be called) in the
future.
The frequency with which an application should call optimize depends on the frequency of
data modifications. If data is frequently added, deleted, or updated then
optimize should be run frequently. A good rule of thumb is to run optimize if
you have added or modified 100,000 or more records or run more than 20 data
modification operations.
"""
try:
asyncio.get_running_loop()
raise AssertionError(
"Synchronous method called in asynchronous context. "
"If you are writing an asynchronous application "
"then please use the asynchronous APIs"
)
except RuntimeError:
asyncio.run(
self._async_optimize(
cleanup_older_than=cleanup_older_than,
delete_unverified=delete_unverified,
)
)
self.checkout_latest()
async def _async_optimize(
self,
cleanup_older_than: Optional[timedelta] = None,
delete_unverified: bool = False,
):
conn = await lancedb_connect(
sanitize_uri(self._conn.uri),
)
table = AsyncTable(await conn.open_table(self.name))
return await table.optimize(
cleanup_older_than=cleanup_older_than, delete_unverified=delete_unverified
)
def add_columns(self, transforms: Dict[str, str]):
self._dataset_mut.add_columns(transforms)
@@ -2382,7 +2513,9 @@ class AsyncTable:
column: str,
*,
replace: Optional[bool] = None,
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
config: Optional[
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
):
"""Create an index to speed up queries
@@ -2535,7 +2668,44 @@ class AsyncTable:
async def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
pass
# The sync remote table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only
# used for that code path right now.
async_query = self.query().limit(query.k)
if query.offset > 0:
async_query = async_query.offset(query.offset)
if query.columns:
async_query = async_query.select(query.columns)
if query.filter:
async_query = async_query.where(query.filter)
if query.fast_search:
async_query = async_query.fast_search()
if query.with_row_id:
async_query = async_query.with_row_id()
if query.vector:
async_query = (
async_query.nearest_to(query.vector)
.distance_type(query.metric)
.nprobes(query.nprobes)
)
if query.refine_factor:
async_query = async_query.refine_factor(query.refine_factor)
if query.vector_column:
async_query = async_query.column(query.vector_column)
if not query.prefilter:
async_query = async_query.postfilter()
if isinstance(query.full_text_query, str):
async_query = async_query.nearest_to_text(query.full_text_query)
elif isinstance(query.full_text_query, dict):
fts_query = query.full_text_query["query"]
fts_columns = query.full_text_query.get("columns", []) or []
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
table = await async_query.to_arrow()
return table.to_reader()
async def _do_merge(
self,
@@ -2781,7 +2951,7 @@ class AsyncTable:
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
return await self._inner.optimize(cleanup_older_than, delete_unverified)
async def list_indices(self) -> IndexConfig:
async def list_indices(self) -> Iterable[IndexConfig]:
"""
List all indices that have been created with Self::create_index
"""
@@ -2865,3 +3035,8 @@ class IndexStatistics:
]
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
num_indices: Optional[int] = None
# This exists for backwards compatibility with an older API, which returned
# a dictionary instead of a class.
def __getitem__(self, key):
return getattr(self, key)
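A minimal sketch of the compatibility shim above: any dataclass-like object with this `__getitem__` supports both attribute and dictionary-style access. The `Stats` class here is a stand-in, not the real `IndexStatistics`:

```python
from dataclasses import dataclass


@dataclass
class Stats:
    num_indexed_rows: int
    index_type: str

    # Same backwards-compatibility shim as above: older callers used
    # stats["num_indexed_rows"], newer callers use stats.num_indexed_rows.
    def __getitem__(self, key):
        return getattr(self, key)


stats = Stats(num_indexed_rows=1000, index_type="BTree")
assert stats.num_indexed_rows == stats["num_indexed_rows"] == 1000
assert stats["index_type"] == "BTree"
```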

View File

@@ -196,6 +196,7 @@ def test_add_optional_vector(tmp_path):
"ollama",
"cohere",
"instructor",
"voyageai",
],
)
def test_embedding_function_safe_model_dump(embedding_type):

View File

@@ -18,7 +18,6 @@ import lancedb
import numpy as np
import pandas as pd
import pytest
import requests
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
@@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path):
@pytest.mark.slow
def test_openclip(tmp_path):
import requests
from PIL import Image
db = lancedb.connect(tmp_path)
@@ -481,3 +481,22 @@ def test_ollama_embedding(tmp_path):
json.dumps(dumped_model)
except TypeError:
pytest.fail("Failed to JSON serialize the dumped model")
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function():
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()

View File

@@ -235,6 +235,29 @@ async def test_search_fts_async(async_table):
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
assert len(results) == 5
expected_count = await async_table.count_rows(
"count > 5000 and contains(text, 'puppy')"
)
expected_count = min(expected_count, 10)
limited_results_pre_filter = await (
async_table.query()
.nearest_to_text("puppy")
.where("count > 5000")
.limit(10)
.to_list()
)
assert len(limited_results_pre_filter) == expected_count
limited_results_post_filter = await (
async_table.query()
.nearest_to_text("puppy")
.where("count > 5000")
.limit(10)
.postfilter()
.to_list()
)
assert len(limited_results_post_filter) <= expected_count
@pytest.mark.asyncio
async def test_search_fts_specify_column_async(async_table):

View File

@@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
# Can recreate if replace=True
await some_table.create_index("id", replace=True)
indices = await some_table.list_indices()
assert str(indices) == '[Index(BTree, columns=["id"])]'
assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
assert len(indices) == 1
assert indices[0].index_type == "BTree"
assert indices[0].columns == ["id"]
@@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
async def test_create_bitmap_index(some_table: AsyncTable):
await some_table.create_index("id", config=Bitmap())
indices = await some_table.list_indices()
assert str(indices) == '[Index(Bitmap, columns=["id"])]'
assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]'
indices = await some_table.list_indices()
assert len(indices) == 1
index_name = indices[0].name
@@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
async def test_create_label_list_index(some_table: AsyncTable):
await some_table.create_index("tags", config=LabelList())
indices = await some_table.list_indices()
assert str(indices) == '[Index(LabelList, columns=["tags"])]'
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
@pytest.mark.asyncio

View File

@@ -17,6 +17,7 @@ from typing import Optional
import lance
import lancedb
from lancedb.index import IvfPq
import numpy as np
import pandas.testing as tm
import pyarrow as pa
@@ -330,6 +331,12 @@ async def test_query_async(table_async: AsyncTable):
# Also check an empty query
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
# with row id
await check_query(
table_async.query().select(["id", "vector"]).with_row_id(),
expected_columns=["id", "vector", "_rowid"],
)
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
@@ -358,6 +365,25 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
assert df.shape == (0, 4)
@pytest.mark.asyncio
async def test_fast_search_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.random.rand(256, 32)
).storage
table = await db.create_table("test", pa.table({"vector": vectors}))
await table.create_index(
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=1)
)
await table.add(pa.table({"vector": vectors}))
q = [1.0] * 32
plan = await table.query().nearest_to(q).explain_plan(True)
assert "LanceScan" in plan
plan = await table.query().nearest_to(q).fast_search().explain_plan(True)
assert "LanceScan" not in plan
def test_explain_plan(table):
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
plan = q.explain_plan(verbose=True)

View File

@@ -1,96 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
@attrs.define
class MockLanceDBServer:
runner: web.AppRunner = attrs.field(init=False)
site: web.TCPSite = attrs.field(init=False)
async def query_handler(self, request: web.Request) -> web.Response:
table_name = request.match_info["table_name"]
assert table_name == "test_table"
await request.json()
# TODO: do some matching
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
ids = pd.Series(range(10), name="id")
df = pd.DataFrame([vecs, ids]).T
batch = pa.RecordBatch.from_pandas(
df,
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 128)),
pa.field("id", pa.int64()),
]
),
)
sink = pa.BufferOutputStream()
with pa.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
return web.Response(body=sink.getvalue().to_pybytes())
async def setup(self):
app = web.Application()
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
self.runner = web.AppRunner(app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", 8111)
async def start(self):
await self.site.start()
async def stop(self):
await self.runner.cleanup()
@pytest.mark.skip(reason="flaky somehow, fix later")
@pytest.mark.asyncio
async def test_e2e_with_mock_server():
mock_server = MockLanceDBServer()
await mock_server.setup()
await mock_server.start()
try:
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
df = (
await client.query(
"test_table",
VectorQuery(
vector=np.random.rand(128).tolist(),
k=10,
_metric="L2",
columns=["id", "vector"],
),
)
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns
assert client.closed
finally:
# make sure we don't leak resources
await mock_server.stop()

View File

@@ -2,91 +2,19 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import contextlib
from datetime import timedelta
import http.server
import json
import threading
from unittest.mock import MagicMock
import uuid
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig
from lancedb.remote.errors import HttpError, RetryError
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest
class FakeLanceDBClient:
def close(self):
pass
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
assert table_name == "test"
t = pa.schema([]).empty_table()
return VectorQueryResult(t)
def post(self, path: str):
pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas()
def test_create_empty_table():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", schema=schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
json_schema = {
"fields": [
{
"name": "vector",
"nullable": True,
"type": {
"type": "fixed_size_list",
"fields": [
{"name": "item", "nullable": True, "type": {"type": "float"}}
],
"length": 2,
},
},
]
}
client.post.return_value = {"schema": json_schema}
assert table.schema == schema
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
client.post.return_value = 0
assert table.count_rows(None) == 0
def test_create_table_with_recordbatches():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", [batch], schema=batch.schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
import pyarrow as pa
def make_mock_http_handler(handler):
@@ -100,8 +28,35 @@ def make_mock_http_handler(handler):
return MockLanceDBHandler
@contextlib.contextmanager
def mock_lancedb_connection(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
db = lancedb.connect(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield db
finally:
server.shutdown()
handle.join()
@contextlib.asynccontextmanager
async def mock_lancedb_connection(handler):
async def mock_lancedb_connection_async(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
@@ -143,7 +98,7 @@ async def test_async_remote_db():
request.end_headers()
request.wfile.write(b'{"tables": []}')
async with mock_lancedb_connection(handler) as db:
async with mock_lancedb_connection_async(handler) as db:
table_names = await db.table_names()
assert table_names == []
@@ -159,12 +114,12 @@ async def test_http_error():
request.end_headers()
request.wfile.write(b"Internal Server Error")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(HttpError) as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 507
assert "Internal Server Error" in str(exc_info.value)
@pytest.mark.asyncio
@@ -178,15 +133,225 @@ async def test_retry_error():
request.end_headers()
request.wfile.write(b"Try again later")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(RetryError) as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 429
cause = exc_info.value.__cause__
assert isinstance(cause, HttpError)
assert "Try again later" in str(cause)
assert cause.request_id == request_id_holder["request_id"]
assert cause.status_code == 429
@contextlib.contextmanager
def query_test_table(query_handler):
def handler(request):
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/query/":
content_len = int(request.headers.get("Content-Length"))
body = request.rfile.read(content_len)
body = json.loads(body)
data = query_handler(body)
request.send_response(200)
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
request.end_headers()
with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
f.write_table(data)
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
assert repr(db) == "RemoteConnect(name=dev)"
table = db.open_table("test")
assert repr(table) == "RemoteTable(dev.test)"
yield table
def test_query_sync_minimal():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": False,
"refine_factor": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
def test_query_sync_maximal():
def handler(body):
assert body == {
"distance_type": "cosine",
"k": 42,
"prefilter": True,
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"filter": "id > 0",
"columns": ["id", "name"],
"vector_column": "vector2",
"fast_search": True,
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
with query_test_table(handler) as table:
(
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
.metric("cosine")
.limit(42)
.refine_factor(10)
.nprobes(5)
.where("id > 0", prefilter=True)
.with_row_id(True)
.select(["id", "name"])
.to_list()
)
def test_query_sync_fts():
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 10,
"vector": [],
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(table.search("puppy", query_type="fts").to_list())
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": ["name", "description"],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(
table.search("puppy", query_type="fts", fts_columns=["name", "description"])
.with_row_id(True)
.limit(42)
.to_list()
)
def test_query_sync_hybrid():
def handler(body):
if "full_text_query" in body:
# FTS query
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
else:
# Vector query
assert body == {
"distance_type": "l2",
"k": 42,
"prefilter": False,
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
with query_test_table(handler) as table:
embedding_func = MockTextEmbeddingFunction()
embedding_config = MagicMock()
embedding_config.function = embedding_func
embedding_funcs = MagicMock()
embedding_funcs.get = MagicMock(return_value=embedding_config)
table.embedding_functions = embedding_funcs
(table.search("puppy", query_type="hybrid").limit(42).to_list())
def test_create_client():
mandatory_args = {
"uri": "db://dev",
"api_key": "fake-api-key",
"region": "us-east-1",
}
db = lancedb.connect(**mandatory_args)
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(**mandatory_args, client_config={})
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(
**mandatory_args,
client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args,
client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
db = lancedb.connect(
**mandatory_args, client_config={"retry_config": {"retries": 42}}
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, connection_timeout=42)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, read_timeout=42)
assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
lancedb.connect(**mandatory_args, request_thread_pool=10)

View File

@@ -16,6 +16,7 @@ from lancedb.rerankers import (
OpenaiReranker,
JinaReranker,
AnswerdotaiRerankers,
VoyageAIReranker,
)
from lancedb.table import LanceTable
@@ -344,3 +345,14 @@ def test_jina_reranker(tmp_path, use_tantivy):
table, schema = get_test_table(tmp_path, use_tantivy)
reranker = JinaReranker()
_run_test_reranker(reranker, table, "single player experience", None, schema)
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_voyageai_reranker(tmp_path, use_tantivy):
pytest.importorskip("voyageai")
reranker = VoyageAIReranker(model_name="rerank-2")
table, schema = get_test_table(tmp_path, use_tantivy)
_run_test_reranker(reranker, table, "single player experience", None, schema)

View File

@@ -1223,6 +1223,54 @@ async def test_time_travel(db_async: AsyncConnection):
await table.restore()
def test_sync_optimize(db):
table = LanceTable.create(
db,
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
table.create_scalar_index("price", index_type="BTREE")
stats = table.to_lance().stats.index_stats("price_idx")
assert stats["num_indexed_rows"] == 2
table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
assert table.count_rows() == 3
table.optimize()
stats = table.to_lance().stats.index_stats("price_idx")
assert stats["num_indexed_rows"] == 3
@pytest.mark.asyncio
async def test_sync_optimize_in_async(db):
table = LanceTable.create(
db,
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
table.create_scalar_index("price", index_type="BTREE")
stats = table.to_lance().stats.index_stats("price_idx")
assert stats["num_indexed_rows"] == 2
table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
assert table.count_rows() == 3
try:
table.optimize()
except Exception as e:
assert (
"Synchronous method called in asynchronous context. "
"If you are writing an asynchronous application "
"then please use the asynchronous APIs" in str(e)
)
@pytest.mark.asyncio
async def test_optimize(db_async: AsyncConnection):
table = await db_async.create_table(

View File

@@ -170,6 +170,17 @@ impl Connection {
})
}
pub fn rename_table(
self_: PyRef<'_, Self>,
old_name: String,
new_name: String,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner.rename_table(old_name, new_name).await.infer_error()
})
}
pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {

View File

@@ -24,8 +24,8 @@ use lancedb::{
DistanceType,
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pymethods, PyResult,
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
};
use crate::util::parse_distance_type;
@@ -236,7 +236,21 @@ pub struct IndexConfig {
#[pymethods]
impl IndexConfig {
pub fn __repr__(&self) -> String {
format!("Index({}, columns={:?})", self.index_type, self.columns)
format!(
"Index({}, columns={:?}, name=\"{}\")",
self.index_type, self.columns, self.name
)
}
// For backwards-compatibility with the old sync SDK, we also support getting
// attributes via __getitem__.
pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult<PyObject> {
match key.as_str() {
"index_type" => Ok(self.index_type.clone().into_py(py)),
"columns" => Ok(self.columns.clone().into_py(py)),
"name" | "index_name" => Ok(self.name.clone().into_py(py)),
_ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))),
}
}
}

View File

@@ -68,6 +68,18 @@ impl Query {
self.inner = self.inner.clone().offset(offset as usize);
}
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
pub fn postfilter(&mut self) {
self.inner = self.inner.clone().postfilter();
}
pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data);
@@ -146,6 +158,14 @@ impl VectorQuery {
self.inner = self.inner.clone().offset(offset as usize);
}
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
}

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.11.1-beta.1"
version = "0.13.0-beta.1"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.11.1-beta.1"
version = "0.13.0-beta.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -39,9 +39,6 @@ use crate::utils::validate_table_name;
use crate::Table;
pub use lance_encoding::version::LanceFileVersion;
#[cfg(feature = "remote")]
use log::warn;
pub const LANCE_FILE_EXTENSION: &str = "lance";
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
@@ -719,8 +716,7 @@ impl ConnectBuilder {
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
// TODO: remove this warning when the remote client is ready
warn!("The rust implementation of the remote client is not yet ready for use.");
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
&self.uri,
&api_key,

View File

@@ -29,6 +29,7 @@ pub mod scalar;
pub mod vector;
/// Supported index types.
#[derive(Debug, Clone)]
pub enum Index {
Auto,
/// A `BTree` index is a sorted index on scalar columns.
@@ -119,6 +120,7 @@ pub enum IndexType {
#[serde(alias = "LABEL_LIST")]
LabelList,
// FTS
#[serde(alias = "INVERTED", alias = "Inverted")]
FTS,
}

View File

@@ -403,6 +403,26 @@ pub trait QueryBase {
/// By default, it is false.
fn fast_search(self) -> Self;
/// If this is called then filtering will happen after the vector search instead of
/// before.
///
/// By default filtering will be performed before the vector search. This is how
/// filtering is typically understood to work. This prefilter step does add some
/// additional latency. Creating a scalar index on the filter column(s) can
/// often improve this latency. However, sometimes a filter is too complex or scalar
/// indices cannot be applied to the column. In these cases postfiltering can be
/// used instead of prefiltering to improve latency.
///
/// Post filtering applies the filter to the results of the vector search. This means
/// we only run the filter on a much smaller set of data. However, it can cause the
/// query to return fewer than `limit` results (or even no results) if none of the nearest
/// results match the filter.
///
/// Post filtering happens during the "refine stage" (described in more detail in
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
/// help restore some of the results lost by post filtering.
fn postfilter(self) -> Self;
/// Return the `_rowid` meta column from the Table.
fn with_row_id(self) -> Self;
}
@@ -442,6 +462,11 @@ impl<T: HasQuery> QueryBase for T {
self
}
fn postfilter(mut self) -> Self {
self.mut_query().prefilter = false;
self
}
fn with_row_id(mut self) -> Self {
self.mut_query().with_row_id = true;
self
@@ -561,6 +586,9 @@ pub struct Query {
///
/// By default, this is false.
pub(crate) with_row_id: bool,
/// If set to false, the filter will be applied after the vector search.
pub(crate) prefilter: bool,
}
impl Query {
@@ -574,6 +602,7 @@ impl Query {
select: Select::All,
fast_search: false,
with_row_id: false,
prefilter: true,
}
}
@@ -678,8 +707,6 @@ pub struct VectorQuery {
pub(crate) distance_type: Option<DistanceType>,
/// Default is true. Set to false to enforce a brute force search.
pub(crate) use_index: bool,
/// Apply the filter before the ANN search.
pub(crate) prefilter: bool,
}
impl VectorQuery {
@@ -692,7 +719,6 @@ impl VectorQuery {
refine_factor: None,
distance_type: None,
use_index: true,
prefilter: true,
}
}
@@ -782,29 +808,6 @@ impl VectorQuery {
self
}
/// If this is called then filtering will happen after the vector search instead of
/// before.
///
/// By default filtering will be performed before the vector search. This is how
/// filtering is typically understood to work. This prefilter step does add some
/// additional latency. Creating a scalar index on the filter column(s) can
/// often improve this latency. However, sometimes a filter is too complex or scalar
/// indices cannot be applied to the column. In these cases postfiltering can be
/// used instead of prefiltering to improve latency.
///
/// Post filtering applies the filter to the results of the vector search. This means
/// we only run the filter on a much smaller set of data. However, it can cause the
/// query to return fewer than `limit` results (or even no results) if none of the nearest
/// results match the filter.
///
/// Post filtering happens during the "refine stage" (described in more detail in
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
/// help restore some of the results lost by post filtering.
pub fn postfilter(mut self) -> Self {
self.prefilter = false;
self
}
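The tradeoff described in the doc comment above can be illustrated with a toy sketch in plain Rust. This is not LanceDB's implementation — just a minimal model of a k-nearest search over precomputed distances, showing why postfiltering can return fewer than `limit` (here `k`) results while prefiltering cannot:

```rust
// Toy model: each item is (row id, distance to the query vector).

fn top_k(mut items: Vec<(u32, f32)>, k: usize) -> Vec<(u32, f32)> {
    // Sort by distance ascending and keep the k nearest.
    items.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    items.truncate(k);
    items
}

// Prefilter: apply the predicate first, then take the k nearest survivors.
fn prefilter_search(
    data: &[(u32, f32)],
    k: usize,
    pred: impl Fn(u32) -> bool,
) -> Vec<(u32, f32)> {
    let filtered: Vec<_> = data.iter().copied().filter(|(id, _)| pred(*id)).collect();
    top_k(filtered, k)
}

// Postfilter: take the k nearest overall, then apply the predicate.
// Rows that fail the filter are simply dropped, so fewer than k
// results (or none) may come back.
fn postfilter_search(
    data: &[(u32, f32)],
    k: usize,
    pred: impl Fn(u32) -> bool,
) -> Vec<(u32, f32)> {
    let nearest = top_k(data.to_vec(), k);
    nearest.into_iter().filter(|(id, _)| pred(*id)).collect()
}

fn main() {
    let data = vec![(1, 0.1), (2, 0.2), (3, 0.3), (4, 0.4), (5, 0.5)];
    let pred = |id: u32| id % 2 == 1; // keep odd ids only
    // Prefilter returns a full 3 results: ids 1, 3, 5.
    println!("prefilter:  {:?}", prefilter_search(&data, 3, pred));
    // Postfilter keeps only ids 1 and 3; id 2 was in the top 3 but fails the filter.
    println!("postfilter: {:?}", postfilter_search(&data, 3, pred));
}
```

Raising the refine factor corresponds, in this model, to taking a larger candidate set before the filter is applied, which restores some of the dropped results.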
/// If this is called then any vector index is skipped
///
/// An exhaustive (flat) search will be performed. The query vector will


@@ -23,6 +23,8 @@ pub(crate) mod table;
pub(crate) mod util;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
#[cfg(test)]
const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};


@@ -341,7 +341,22 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
request_id
};
debug!("Sending request_id={}: {:?}", request_id, &request);
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
if with_retry {
self.send_with_retry_impl(client, request, request_id).await


@@ -161,7 +161,7 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
if self.table_cache.get(&options.name).is_none() {
let req = self
.client
.get(&format!("/v1/table/{}/describe/", options.name));
.post(&format!("/v1/table/{}/describe/", options.name));
let (request_id, resp) = self.client.send(req, true).await?;
if resp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: options.name });
@@ -301,7 +301,7 @@ mod tests {
#[tokio::test]
async fn test_open_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/describe/");
assert_eq!(request.url().query(), None);


@@ -1,3 +1,4 @@
use std::io::Cursor;
use std::sync::{Arc, Mutex};
use crate::index::Index;
@@ -7,10 +8,9 @@ use crate::table::AddDataMode;
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::Error;
use arrow_array::RecordBatchReader;
use arrow_ipc::reader::StreamReader;
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
use bytes::Buf;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
@@ -115,39 +115,14 @@ impl<S: HttpSend> RemoteTable<S> {
async fn read_arrow_stream(
&self,
request_id: &str,
body: reqwest::Response,
response: reqwest::Response,
) -> Result<SendableRecordBatchStream> {
// Assert that the content type is correct
let content_type = body
.headers()
.get(CONTENT_TYPE)
.ok_or_else(|| Error::Http {
source: "Missing content type".into(),
request_id: request_id.to_string(),
status_code: None,
})?
.to_str()
.map_err(|e| Error::Http {
source: format!("Failed to parse content type: {}", e).into(),
request_id: request_id.to_string(),
status_code: None,
})?;
if content_type != ARROW_STREAM_CONTENT_TYPE {
return Err(Error::Http {
source: format!(
"Expected content type {}, got {}",
ARROW_STREAM_CONTENT_TYPE, content_type
)
.into(),
request_id: request_id.to_string(),
status_code: None,
});
}
let response = self.check_table_response(request_id, response).await?;
// There isn't a way to actually stream this data yet. I have an upstream issue:
// https://github.com/apache/arrow-rs/issues/6420
let body = body.bytes().await.err_to_http(request_id.into())?;
let reader = StreamReader::try_new(body.reader(), None)?;
let body = response.bytes().await.err_to_http(request_id.into())?;
let reader = FileReader::try_new(Cursor::new(body), None)?;
let schema = reader.schema();
let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
@@ -192,6 +167,10 @@ impl<S: HttpSend> RemoteTable<S> {
body["fast_search"] = serde_json::Value::Bool(true);
}
if params.with_row_id {
body["with_row_id"] = serde_json::Value::Bool(true);
}
if let Some(full_text_search) = &params.full_text_search {
if full_text_search.wand_factor.is_some() {
return Err(Error::NotSupported {
@@ -277,7 +256,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
.post(&format!("/v1/table/{}/count_rows/", self.name));
if let Some(filter) = filter {
request = request.json(&serde_json::json!({ "filter": filter }));
request = request.json(&serde_json::json!({ "predicate": filter }));
} else {
request = request.json(&serde_json::json!({}));
}
@@ -330,13 +309,13 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, &query.base)?;
body["prefilter"] = query.prefilter.into();
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["refine_factor"] = query.refine_factor.into();
if let Some(vector) = query.query_vector.as_ref() {
let vector: Vec<f32> = match vector.data_type() {
let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
match vector.data_type() {
DataType::Float32 => vector
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
@@ -350,9 +329,12 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
message: "VectorQuery vector must be of type Float32".into(),
})
}
};
body["vector"] = serde_json::json!(vector);
}
}
} else {
// Server takes empty vector, not null or undefined.
Vec::new()
};
body["vector"] = serde_json::json!(vector);
if let Some(vector_column) = query.column.as_ref() {
body["vector_column"] = serde_json::Value::String(vector_column.clone());
@@ -383,6 +365,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, query)?;
// Empty vector can be passed if no vector search is performed.
body["vector"] = serde_json::Value::Array(Vec::new());
let request = request.json(&body);
@@ -399,30 +383,19 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut updates = Vec::new();
for (column, expression) in update.columns {
updates.push(column);
updates.push(expression);
updates.push(vec![column, expression]);
}
let request = request.json(&serde_json::json!({
"updates": updates,
"only_if": update.filter,
"predicate": update.filter,
}));
let (request_id, response) = self.client.send(request, false).await?;
let response = self.check_table_response(&request_id, response).await?;
self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?;
serde_json::from_str(&body).map_err(|e| Error::Http {
source: format!(
"Failed to parse updated rows result from response {}: {}",
body, e
)
.into(),
request_id,
status_code: None,
})
Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
}
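The hunk above changes how `update()` encodes its column assignments: the old code flattened columns and expressions into one interleaved list, while the new code sends one `[column, expression]` pair per assignment. A minimal sketch of the two encodings (plain Rust with `String` vectors standing in for the JSON payload, not the actual serde types used by the client):

```rust
// Old encoding: ["a", "a + 1", "b", "b - 1"] — columns and expressions
// interleaved in a single flat list.
fn encode_updates_flat(columns: &[(&str, &str)]) -> Vec<String> {
    let mut updates = Vec::new();
    for (column, expression) in columns {
        updates.push(column.to_string());
        updates.push(expression.to_string());
    }
    updates
}

// New encoding: [["a", "a + 1"], ["b", "b - 1"]] — one pair per update,
// so each assignment is self-delimiting.
fn encode_updates_paired(columns: &[(&str, &str)]) -> Vec<Vec<String>> {
    columns
        .iter()
        .map(|(column, expression)| vec![column.to_string(), expression.to_string()])
        .collect()
}

fn main() {
    let cols = [("a", "a + 1"), ("b", "b - 1")];
    println!("flat:   {:?}", encode_updates_flat(&cols));
    println!("paired: {:?}", encode_updates_paired(&cols));
}
```

The paired form is what the updated test below asserts on (`updates[0][0]` is the column name, `updates[0][1]` the expression), alongside the `only_if` → `predicate` key rename.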
async fn delete(&self, predicate: &str) -> Result<()> {
let body = serde_json::json!({ "predicate": predicate });
@@ -691,6 +664,7 @@ mod tests {
use crate::{
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
query::{ExecutableQuery, QueryBase},
remote::ARROW_FILE_CONTENT_TYPE,
DistanceType, Error, Table,
};
@@ -804,7 +778,7 @@ mod tests {
);
assert_eq!(
request.body().unwrap().as_bytes().unwrap(),
br#"{"filter":"a > 10"}"#
br#"{"predicate":"a > 10"}"#
);
http::Response::builder().status(200).body("42").unwrap()
@@ -839,6 +813,17 @@ mod tests {
body
}
fn write_ipc_file(data: &RecordBatch) -> Vec<u8> {
let mut body = Vec::new();
{
let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema())
.expect("Failed to create writer");
writer.write(data).expect("Failed to write data");
writer.finish().expect("Failed to finish");
}
body
}
#[tokio::test]
async fn test_add_append() {
let data = RecordBatch::try_new(
@@ -947,21 +932,27 @@ mod tests {
let updates = value.get("updates").unwrap().as_array().unwrap();
assert!(updates.len() == 2);
let col_name = updates[0].as_str().unwrap();
let expression = updates[1].as_str().unwrap();
let col_name = updates[0][0].as_str().unwrap();
let expression = updates[0][1].as_str().unwrap();
assert_eq!(col_name, "a");
assert_eq!(expression, "a + 1");
let only_if = value.get("only_if").unwrap().as_str().unwrap();
let col_name = updates[1][0].as_str().unwrap();
let expression = updates[1][1].as_str().unwrap();
assert_eq!(col_name, "b");
assert_eq!(expression, "b - 1");
let only_if = value.get("predicate").unwrap().as_str().unwrap();
assert_eq!(only_if, "b > 10");
}
http::Response::builder().status(200).body("1").unwrap()
http::Response::builder().status(200).body("{}").unwrap()
});
table
.update()
.column("a", "a + 1")
.column("b", "b - 1")
.only_if("b > 10")
.execute()
.await
@@ -1092,10 +1083,10 @@ mod tests {
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
assert_eq!(body, expected_body);
let response_body = write_ipc_stream(&expected_data_ref);
let response_body = write_ipc_file(&expected_data_ref);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
@@ -1142,10 +1133,10 @@ mod tests {
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let response_body = write_ipc_stream(&data);
let response_body = write_ipc_file(&data);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
@@ -1185,6 +1176,8 @@ mod tests {
"query": "hello world",
},
"k": 10,
"vector": [],
"with_row_id": true,
});
assert_eq!(body, expected_body);
@@ -1193,10 +1186,10 @@ mod tests {
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let response_body = write_ipc_stream(&data);
let response_body = write_ipc_file(&data);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
@@ -1207,6 +1200,7 @@ mod tests {
FullTextSearchQuery::new("hello world".into())
.columns(Some(vec!["a".into(), "b".into()])),
)
.with_row_id()
.limit(10)
.execute()
.await


@@ -1842,7 +1842,7 @@ impl TableInternal for NativeTable {
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.prefilter(query.base.prefilter);
match query.base.select {
Select::Columns(ref columns) => {
scanner.project(columns.as_slice())?;
@@ -3123,6 +3123,12 @@ mod tests {
assert_eq!(index.index_type, crate::index::IndexType::FTS);
assert_eq!(index.columns, vec!["text".to_string()]);
assert_eq!(index.name, "text_idx");
let stats = table.index_stats("text_idx").await.unwrap().unwrap();
assert_eq!(stats.num_indexed_rows, num_rows);
assert_eq!(stats.num_unindexed_rows, 0);
assert_eq!(stats.index_type, crate::index::IndexType::FTS);
assert_eq!(stats.distance_type, None);
}
#[tokio::test]