mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 14:29:56 +00:00
Compare commits
38 Commits
python-v0.
...
yang/relat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f69b673c1e | ||
|
|
729718cb09 | ||
|
|
b1c84e0bda | ||
|
|
cbbc07d0f5 | ||
|
|
21021f94ca | ||
|
|
0ed77fa990 | ||
|
|
4372c231cd | ||
|
|
4c6b728a31 | ||
|
|
138a12a427 | ||
|
|
fa9ca8f7a6 | ||
|
|
2a35d24ee6 | ||
|
|
dd9ce337e2 | ||
|
|
b9921d56cc | ||
|
|
0cfd9ed18e | ||
|
|
975398c3a8 | ||
|
|
08d5f93f34 | ||
|
|
91cab3b556 | ||
|
|
c61bfc3af8 | ||
|
|
4e8c7b0adf | ||
|
|
26f4a80e10 | ||
|
|
3604d20ad3 | ||
|
|
9708d829a9 | ||
|
|
059c9794b5 | ||
|
|
15ed7f75a0 | ||
|
|
96181ab421 | ||
|
|
0c108407ab | ||
|
|
a7fead3801 | ||
|
|
f3fc339ef6 | ||
|
|
113cd6995b | ||
|
|
02535bdc88 | ||
|
|
facc7d61c0 | ||
|
|
f947259f16 | ||
|
|
50c68feae9 | ||
|
|
f30c5b24fa | ||
|
|
2a477ad387 | ||
|
|
0b29aca23b | ||
|
|
df62c3d9ac | ||
|
|
aef4656053 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.11.1-beta.1"
|
||||
current_version = "0.13.0-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
@@ -92,6 +92,11 @@ glob = "node/package.json"
|
||||
replace = "\"@lancedb/vectordb-win32-x64-msvc\": \"{new_version}\""
|
||||
search = "\"@lancedb/vectordb-win32-x64-msvc\": \"{current_version}\""
|
||||
|
||||
[[tool.bumpversion.files]]
|
||||
glob = "node/package.json"
|
||||
replace = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{new_version}\""
|
||||
search = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{current_version}\""
|
||||
|
||||
# Cargo files
|
||||
# ------------
|
||||
[[tool.bumpversion.files]]
|
||||
|
||||
@@ -38,3 +38,7 @@ rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm
|
||||
# not found errors on systems that are missing it.
|
||||
[target.x86_64-pc-windows-msvc]
|
||||
rustflags = ["-Ctarget-feature=+crt-static"]
|
||||
|
||||
# Experimental target for Arm64 Windows
|
||||
[target.aarch64-pc-windows-msvc]
|
||||
rustflags = ["-Ctarget-feature=+crt-static"]
|
||||
234
.github/workflows/npm-publish.yml
vendored
234
.github/workflows/npm-publish.yml
vendored
@@ -226,6 +226,126 @@ jobs:
|
||||
path: |
|
||||
node/dist/lancedb-vectordb-win32*.tgz
|
||||
|
||||
node-windows-arm64:
|
||||
name: vectordb win32-arm64-msvc
|
||||
runs-on: windows-4x-arm
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache installations
|
||||
id: cache-installs
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
C:\Program Files\Git
|
||||
C:\BuildTools
|
||||
C:\Program Files (x86)\Windows Kits
|
||||
C:\Program Files\7-Zip
|
||||
C:\protoc
|
||||
key: ${{ runner.os }}-arm64-installs-v1
|
||||
restore-keys: |
|
||||
${{ runner.os }}-arm64-installs-
|
||||
- name: Install Git
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
|
||||
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
|
||||
shell: powershell
|
||||
- name: Add Git to PATH
|
||||
run: |
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
|
||||
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
shell: powershell
|
||||
- name: Configure Git symlinks
|
||||
run: git config --global core.symlinks true
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.13"
|
||||
- name: Install Visual Studio Build Tools
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
|
||||
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
|
||||
"--installPath", "C:\BuildTools", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
|
||||
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
|
||||
shell: powershell
|
||||
- name: Add Visual Studio Build Tools to PATH
|
||||
run: |
|
||||
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
|
||||
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
|
||||
|
||||
# Add MSVC runtime libraries to LIB
|
||||
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
|
||||
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
|
||||
|
||||
# Add INCLUDE paths
|
||||
$env:INCLUDE = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\include;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\ucrt;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\um;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\shared"
|
||||
Add-Content $env:GITHUB_ENV "INCLUDE=$env:INCLUDE"
|
||||
shell: powershell
|
||||
- name: Install Rust
|
||||
run: |
|
||||
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
|
||||
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
|
||||
shell: powershell
|
||||
- name: Add Rust to PATH
|
||||
run: |
|
||||
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
|
||||
shell: powershell
|
||||
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: rust
|
||||
- name: Install 7-Zip ARM
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
New-Item -Path 'C:\7zip' -ItemType Directory
|
||||
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
|
||||
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
|
||||
shell: powershell
|
||||
- name: Add 7-Zip to PATH
|
||||
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
|
||||
shell: powershell
|
||||
- name: Install Protoc v21.12
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
working-directory: C:\
|
||||
run: |
|
||||
if (Test-Path 'C:\protoc') {
|
||||
Write-Host "Protoc directory exists, skipping installation"
|
||||
return
|
||||
}
|
||||
New-Item -Path 'C:\protoc' -ItemType Directory
|
||||
Set-Location C:\protoc
|
||||
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
|
||||
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
|
||||
shell: powershell
|
||||
- name: Add Protoc to PATH
|
||||
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
|
||||
shell: powershell
|
||||
- name: Build Windows native node modules
|
||||
run: .\ci\build_windows_artifacts.ps1 aarch64-pc-windows-msvc
|
||||
- name: Upload Windows ARM64 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: node-native-windows-arm64
|
||||
path: |
|
||||
node/dist/*.node
|
||||
|
||||
nodejs-windows:
|
||||
name: lancedb ${{ matrix.target }}
|
||||
runs-on: windows-2022
|
||||
@@ -260,9 +380,119 @@ jobs:
|
||||
path: |
|
||||
nodejs/dist/*.node
|
||||
|
||||
nodejs-windows-arm64:
|
||||
name: lancedb win32-arm64-msvc
|
||||
runs-on: windows-4x-arm
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache installations
|
||||
id: cache-installs
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
C:\Program Files\Git
|
||||
C:\BuildTools
|
||||
C:\Program Files (x86)\Windows Kits
|
||||
C:\Program Files\7-Zip
|
||||
C:\protoc
|
||||
key: ${{ runner.os }}-arm64-installs-v1
|
||||
restore-keys: |
|
||||
${{ runner.os }}-arm64-installs-
|
||||
- name: Install Git
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
|
||||
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
|
||||
shell: powershell
|
||||
- name: Add Git to PATH
|
||||
run: |
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
|
||||
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
shell: powershell
|
||||
- name: Configure Git symlinks
|
||||
run: git config --global core.symlinks true
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.13"
|
||||
- name: Install Visual Studio Build Tools
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
|
||||
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
|
||||
"--installPath", "C:\BuildTools", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
|
||||
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
|
||||
shell: powershell
|
||||
- name: Add Visual Studio Build Tools to PATH
|
||||
run: |
|
||||
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
|
||||
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
|
||||
|
||||
$env:LIB = ""
|
||||
Add-Content $env:GITHUB_ENV "LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
|
||||
shell: powershell
|
||||
- name: Install Rust
|
||||
run: |
|
||||
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
|
||||
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
|
||||
shell: powershell
|
||||
- name: Add Rust to PATH
|
||||
run: |
|
||||
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
|
||||
shell: powershell
|
||||
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: rust
|
||||
- name: Install 7-Zip ARM
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
New-Item -Path 'C:\7zip' -ItemType Directory
|
||||
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
|
||||
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
|
||||
shell: powershell
|
||||
- name: Add 7-Zip to PATH
|
||||
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
|
||||
shell: powershell
|
||||
- name: Install Protoc v21.12
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
working-directory: C:\
|
||||
run: |
|
||||
if (Test-Path 'C:\protoc') {
|
||||
Write-Host "Protoc directory exists, skipping installation"
|
||||
return
|
||||
}
|
||||
New-Item -Path 'C:\protoc' -ItemType Directory
|
||||
Set-Location C:\protoc
|
||||
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
|
||||
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
|
||||
shell: powershell
|
||||
- name: Add Protoc to PATH
|
||||
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
|
||||
shell: powershell
|
||||
- name: Build Windows native node modules
|
||||
run: .\ci\build_windows_artifacts_nodejs.ps1 aarch64-pc-windows-msvc
|
||||
- name: Upload Windows ARM64 Artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: nodejs-native-windows-arm64
|
||||
path: |
|
||||
nodejs/dist/*.node
|
||||
|
||||
release:
|
||||
name: vectordb NPM Publish
|
||||
needs: [node, node-macos, node-linux, node-windows]
|
||||
needs: [node, node-macos, node-linux, node-windows, node-windows-arm64]
|
||||
runs-on: ubuntu-latest
|
||||
# Only runs on tags that matches the make-release action
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
@@ -302,7 +532,7 @@ jobs:
|
||||
|
||||
release-nodejs:
|
||||
name: lancedb NPM Publish
|
||||
needs: [nodejs-macos, nodejs-linux, nodejs-windows]
|
||||
needs: [nodejs-macos, nodejs-linux, nodejs-windows, nodejs-windows-arm64]
|
||||
runs-on: ubuntu-latest
|
||||
# Only runs on tags that matches the make-release action
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
|
||||
183
.github/workflows/rust.yml
vendored
183
.github/workflows/rust.yml
vendored
@@ -35,21 +35,21 @@ jobs:
|
||||
CC: clang-18
|
||||
CXX: clang++-18
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: rust
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: rust
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
- name: Run format
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Run clippy
|
||||
run: cargo clippy --workspace --tests --all-features -- -D warnings
|
||||
- name: Run format
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Run clippy
|
||||
run: cargo clippy --workspace --tests --all-features -- -D warnings
|
||||
linux:
|
||||
timeout-minutes: 30
|
||||
# To build all features, we need more disk space than is available
|
||||
@@ -65,37 +65,37 @@ jobs:
|
||||
CC: clang-18
|
||||
CXX: clang++-18
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: rust
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
- name: Make Swap
|
||||
run: |
|
||||
sudo fallocate -l 16G /swapfile
|
||||
sudo chmod 600 /swapfile
|
||||
sudo mkswap /swapfile
|
||||
sudo swapon /swapfile
|
||||
- name: Start S3 integration test environment
|
||||
working-directory: .
|
||||
run: docker compose up --detach --wait
|
||||
- name: Build
|
||||
run: cargo build --all-features
|
||||
- name: Run tests
|
||||
run: cargo test --all-features
|
||||
- name: Run examples
|
||||
run: cargo run --example simple
|
||||
- name: Make Swap
|
||||
run: |
|
||||
sudo fallocate -l 16G /swapfile
|
||||
sudo chmod 600 /swapfile
|
||||
sudo mkswap /swapfile
|
||||
sudo swapon /swapfile
|
||||
- name: Start S3 integration test environment
|
||||
working-directory: .
|
||||
run: docker compose up --detach --wait
|
||||
- name: Build
|
||||
run: cargo build --all-features
|
||||
- name: Run tests
|
||||
run: cargo test --all-features
|
||||
- name: Run examples
|
||||
run: cargo run --example simple
|
||||
macos:
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
matrix:
|
||||
mac-runner: [ "macos-13", "macos-14" ]
|
||||
mac-runner: ["macos-13", "macos-14"]
|
||||
runs-on: "${{ matrix.mac-runner }}"
|
||||
defaults:
|
||||
run:
|
||||
@@ -104,8 +104,8 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- name: CPU features
|
||||
run: sysctl -a | grep cpu
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
@@ -139,3 +139,116 @@ jobs:
|
||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||
cargo build
|
||||
cargo test
|
||||
windows-arm64:
|
||||
runs-on: windows-4x-arm
|
||||
steps:
|
||||
- name: Cache installations
|
||||
id: cache-installs
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
C:\Program Files\Git
|
||||
C:\BuildTools
|
||||
C:\Program Files (x86)\Windows Kits
|
||||
C:\Program Files\7-Zip
|
||||
C:\protoc
|
||||
key: ${{ runner.os }}-arm64-installs-v1
|
||||
restore-keys: |
|
||||
${{ runner.os }}-arm64-installs-
|
||||
- name: Install Git
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
|
||||
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
|
||||
shell: powershell
|
||||
- name: Add Git to PATH
|
||||
run: |
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
|
||||
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||
shell: powershell
|
||||
- name: Configure Git symlinks
|
||||
run: git config --global core.symlinks true
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.13"
|
||||
- name: Install Visual Studio Build Tools
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
|
||||
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
|
||||
"--installPath", "C:\BuildTools", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
|
||||
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
|
||||
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
|
||||
shell: powershell
|
||||
- name: Add Visual Studio Build Tools to PATH
|
||||
run: |
|
||||
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
|
||||
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
|
||||
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
|
||||
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
|
||||
|
||||
# Add MSVC runtime libraries to LIB
|
||||
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
|
||||
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
|
||||
|
||||
# Add INCLUDE paths
|
||||
$env:INCLUDE = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\include;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\ucrt;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\um;" +
|
||||
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\shared"
|
||||
Add-Content $env:GITHUB_ENV "INCLUDE=$env:INCLUDE"
|
||||
shell: powershell
|
||||
- name: Install Rust
|
||||
run: |
|
||||
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
|
||||
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
|
||||
shell: powershell
|
||||
- name: Add Rust to PATH
|
||||
run: |
|
||||
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
|
||||
shell: powershell
|
||||
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: rust
|
||||
- name: Install 7-Zip ARM
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
New-Item -Path 'C:\7zip' -ItemType Directory
|
||||
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
|
||||
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
|
||||
shell: powershell
|
||||
- name: Add 7-Zip to PATH
|
||||
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
|
||||
shell: powershell
|
||||
- name: Install Protoc v21.12
|
||||
if: steps.cache-installs.outputs.cache-hit != 'true'
|
||||
working-directory: C:\
|
||||
run: |
|
||||
if (Test-Path 'C:\protoc') {
|
||||
Write-Host "Protoc directory exists, skipping installation"
|
||||
return
|
||||
}
|
||||
New-Item -Path 'C:\protoc' -ItemType Directory
|
||||
Set-Location C:\protoc
|
||||
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
|
||||
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
|
||||
shell: powershell
|
||||
- name: Add Protoc to PATH
|
||||
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
|
||||
shell: powershell
|
||||
- name: Run tests
|
||||
run: |
|
||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||
cargo build --target aarch64-pc-windows-msvc
|
||||
cargo test --target aarch64-pc-windows-msvc
|
||||
|
||||
14
Cargo.toml
14
Cargo.toml
@@ -21,13 +21,13 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.19.1", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.19.1" }
|
||||
lance-linalg = { "version" = "=0.19.1" }
|
||||
lance-table = { "version" = "=0.19.1" }
|
||||
lance-testing = { "version" = "=0.19.1" }
|
||||
lance-datafusion = { "version" = "=0.19.1" }
|
||||
lance-encoding = { "version" = "=0.19.1" }
|
||||
lance = { "version" = "=0.19.2", "features" = ["dynamodb"], path = "../lance/rust/lance"}
|
||||
lance-index = { "version" = "=0.19.2", path = "../lance/rust/lance-index"}
|
||||
lance-linalg = { "version" = "=0.19.2", path = "../lance/rust/lance-linalg"}
|
||||
lance-testing = { "version" = "=0.19.2", path = "../lance/rust/lance-testing"}
|
||||
lance-datafusion = { "version" = "=0.19.2", path = "../lance/rust/lance-datafusion"}
|
||||
lance-encoding = { "version" = "=0.19.2", path = "../lance/rust/lance-encoding"}
|
||||
lance-table = { "version" = "=0.19.2", path = "../lance/rust/lance-table"}
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "52.2", optional = false }
|
||||
arrow-array = "52.2"
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
[](https://blog.lancedb.com/)
|
||||
[](https://discord.gg/zMM32dvNtd)
|
||||
[](https://twitter.com/lancedb)
|
||||
[](https://gurubase.io/g/lancedb)
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Targets supported:
|
||||
# - x86_64-pc-windows-msvc
|
||||
# - i686-pc-windows-msvc
|
||||
# - aarch64-pc-windows-msvc
|
||||
|
||||
function Prebuild-Rust {
|
||||
param (
|
||||
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
|
||||
|
||||
$targets = $args[0]
|
||||
if (-not $targets) {
|
||||
$targets = "x86_64-pc-windows-msvc"
|
||||
$targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
|
||||
}
|
||||
|
||||
Write-Host "Building artifacts for targets: $targets"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Targets supported:
|
||||
# - x86_64-pc-windows-msvc
|
||||
# - i686-pc-windows-msvc
|
||||
# - aarch64-pc-windows-msvc
|
||||
|
||||
function Prebuild-Rust {
|
||||
param (
|
||||
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
|
||||
|
||||
$targets = $args[0]
|
||||
if (-not $targets) {
|
||||
$targets = "x86_64-pc-windows-msvc"
|
||||
$targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
|
||||
}
|
||||
|
||||
Write-Host "Building artifacts for targets: $targets"
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
# VoyageAI Embeddings
|
||||
|
||||
Voyage AI provides cutting-edge embedding and rerankers.
|
||||
|
||||
|
||||
Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
|
||||
You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.
|
||||
|
||||
Supported models are:
|
||||
|
||||
- voyage-3
|
||||
- voyage-3-lite
|
||||
- voyage-finance-2
|
||||
- voyage-multilingual-2
|
||||
- voyage-law-2
|
||||
- voyage-code-2
|
||||
|
||||
|
||||
Supported parameters (to be passed in `create` method) are:
|
||||
|
||||
| Parameter | Type | Default Value | Description |
|
||||
|---|---|--------|---------|
|
||||
| `name` | `str` | `"voyage-3"` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
|
||||
| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. |
|
||||
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
|
||||
|
||||
|
||||
Usage Example:
|
||||
|
||||
```python
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
voyageai = EmbeddingFunctionRegistry
|
||||
.get_instance()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-3")
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
data = [ { "text": "hello world" },
|
||||
{ "text": "goodbye world" }]
|
||||
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(data)
|
||||
```
|
||||
77
docs/src/reranking/voyageai.md
Normal file
77
docs/src/reranking/voyageai.md
Normal file
@@ -0,0 +1,77 @@
|
||||
# Voyage AI Reranker
|
||||
|
||||
Voyage AI provides cutting-edge embedding and rerankers.
|
||||
|
||||
This re-ranker uses the [VoyageAI](https://docs.voyageai.com/docs/) API to rerank the search results. You can use this re-ranker by passing `VoyageAIReranker()` to the `rerank()` method. Note that you'll either need to set the `VOYAGE_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker.
|
||||
|
||||
|
||||
!!! note
|
||||
Supported Query Types: Hybrid, Vector, FTS
|
||||
|
||||
|
||||
```python
|
||||
import numpy
|
||||
import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import VoyageAIReranker
|
||||
|
||||
embedder = get_registry().get("sentence-transformers").create()
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = embedder.SourceField()
|
||||
vector: Vector(embedder.ndims()) = embedder.VectorField()
|
||||
|
||||
data = [
|
||||
{"text": "hello world"},
|
||||
{"text": "goodbye world"}
|
||||
]
|
||||
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
tbl.add(data)
|
||||
reranker = VoyageAIReranker(model_name="rerank-2")
|
||||
|
||||
# Run vector search with a reranker
|
||||
result = tbl.search("hello").rerank(reranker=reranker).to_list()
|
||||
|
||||
# Run FTS search with a reranker
|
||||
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()
|
||||
|
||||
# Run hybrid search with a reranker
|
||||
tbl.create_fts_index("text", replace=True)
|
||||
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
|
||||
|
||||
```
|
||||
|
||||
Accepted Arguments
|
||||
----------------
|
||||
| Argument | Type | Default | Description |
|
||||
| --- | --- | --- | --- |
|
||||
| `model_name` | `str` | `None` | The name of the reranker model to use. Available models are: rerank-2, rerank-2-lite |
|
||||
| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. |
|
||||
| `top_n` | `str` | `None` | The number of results to return. If None, will return all results. |
|
||||
| `api_key` | `str` | `None` | The API key for the Voyage AI API. If not provided, the `VOYAGE_API_KEY` environment variable is used. |
|
||||
| `return_score` | str | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score. If "all" is supported, will return relevance score along with the vector and/or fts scores depending on query type |
|
||||
| `truncation` | `bool` | `None` | Whether to truncate the input to satisfy the "context length limit" on the query and the documents. |
|
||||
|
||||
|
||||
## Supported Scores for each query type
|
||||
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:
|
||||
|
||||
### Hybrid Search
|
||||
|`return_score`| Status | Description |
|
||||
| --- | --- | --- |
|
||||
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||
| `all` | ❌ Not Supported | Returns have vector(`_distance`) and FTS(`score`) along with Hybrid Search score(`_relevance_score`) |
|
||||
|
||||
### Vector Search
|
||||
|`return_score`| Status | Description |
|
||||
| --- | --- | --- |
|
||||
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||
| `all` | ✅ Supported | Returns have vector(`_distance`) along with Hybrid Search score(`_relevance_score`) |
|
||||
|
||||
### FTS Search
|
||||
|`return_score`| Status | Description |
|
||||
| --- | --- | --- |
|
||||
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||
| `all` | ✅ Supported | Returns have FTS(`score`) along with Hybrid Search score(`_relevance_score`) |
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.11.1-beta.1</version>
|
||||
<version>0.13.0-beta.1</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.11.1-beta.1</version>
|
||||
<version>0.13.0-beta.1</version>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<name>LanceDB Parent</name>
|
||||
|
||||
50
node/package-lock.json
generated
50
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -52,11 +52,12 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@apache-arrow/ts": "^14.0.2",
|
||||
@@ -327,65 +328,60 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-q9jcCbmcz45UHmjgecL6zK82WaqUJsARfniwXXPcnd8ooISVhPkgN+RVKv6edwI9T0PV+xVRYq+LQLlZu5fyxw==",
|
||||
"version": "0.13.0-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.13.0-beta.1.tgz",
|
||||
"integrity": "sha512-beOrf6selCzzhLgDG8Nibma4nO/CSnA1wUKRmlJHEPtGcg7PW18z6MP/nfwQMpMR/FLRfTo8pPTbpzss47MiQQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-E5tCTS5TaTkssTPa+gdnFxZJ1f60jnSIJXhqufNFZk4s+IMViwR1BPqaqE++WY5c1uBI55ef1862CROKDKX4gg==",
|
||||
"version": "0.13.0-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.13.0-beta.1.tgz",
|
||||
"integrity": "sha512-YdraGRF/RbJRkKh0v3xT03LUhq47T2GtCvJ5gZp8wKlh4pHa8LuhLU0DIdvmG/DT5vuQA+td8HDkBm/e3EOdNg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-Obohy6TH31Uq+fp6ZisHR7iAsvgVPqBExrycVcIJqrLZnIe88N9OWUwBXkmfMAw/2hNJFwD4tU7+4U2FcBWX4w==",
|
||||
"version": "0.13.0-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.13.0-beta.1.tgz",
|
||||
"integrity": "sha512-Pp0O/uhEqof1oLaWrNbv+Ym+q8kBkiCqaA5+2eAZ6a3e9U+Ozkvb0FQrHuyi9adJ5wKQ4NabyQE9BMf2bYpOnQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-3Meu0dgrzNrnBVVQhxkUSAOhQNmgtKHvOvmrRLUicV+X19hd33udihgxVpZZb9mpXenJ8lZsS+Jq6R0hWqntag==",
|
||||
"version": "0.13.0-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.13.0-beta.1.tgz",
|
||||
"integrity": "sha512-y8nxOye4egfWF5FGED9EfkmZ1O5HnRLU4a61B8m5JSpkivO9v2epTcbYN0yt/7ZFCgtqMfJ8VW4Mi7qQcz3KDA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-BafZ9OJPQXsS7JW0weAl12wC+827AiRjfUrE5tvrYWZah2OwCF2U2g6uJ3x4pxfwEGsv5xcHFqgxlS7ttFkh+Q==",
|
||||
"version": "0.13.0-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.13.0-beta.1.tgz",
|
||||
"integrity": "sha512-STMDP9dp0TBLkB3ro+16pKcGy6bmbhRuEZZZ1Tp5P75yTPeVh4zIgWkidMdU1qBbEYM7xacnsp9QAwgLnMU/Ow==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -84,14 +84,16 @@
|
||||
"aarch64-apple-darwin": "@lancedb/vectordb-darwin-arm64",
|
||||
"x86_64-unknown-linux-gnu": "@lancedb/vectordb-linux-x64-gnu",
|
||||
"aarch64-unknown-linux-gnu": "@lancedb/vectordb-linux-arm64-gnu",
|
||||
"x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc"
|
||||
"x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc",
|
||||
"aarch64-pc-windows-msvc": "@lancedb/vectordb-win32-arm64-msvc"
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1",
|
||||
"@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import axios, { type AxiosResponse, type ResponseType } from 'axios'
|
||||
import axios, { type AxiosError, type AxiosResponse, type ResponseType } from 'axios'
|
||||
|
||||
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
|
||||
|
||||
@@ -197,7 +197,7 @@ export class HttpLancedbClient {
|
||||
response = await callWithMiddlewares(req, this._middlewares)
|
||||
return response
|
||||
} catch (err: any) {
|
||||
console.error('error: ', err)
|
||||
console.error(serializeErrorAsJson(err))
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
@@ -247,7 +247,8 @@ export class HttpLancedbClient {
|
||||
|
||||
// return response
|
||||
} catch (err: any) {
|
||||
console.error('error: ', err)
|
||||
console.error(serializeErrorAsJson(err))
|
||||
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
@@ -287,3 +288,15 @@ export class HttpLancedbClient {
|
||||
return clone
|
||||
}
|
||||
}
|
||||
|
||||
function serializeErrorAsJson(err: AxiosError) {
|
||||
const error = JSON.parse(JSON.stringify(err, Object.getOwnPropertyNames(err)))
|
||||
error.response = err.response != null
|
||||
? JSON.parse(JSON.stringify(
|
||||
err.response,
|
||||
// config contains the request data, too noisy
|
||||
Object.getOwnPropertyNames(err.response).filter(prop => prop !== 'config')
|
||||
))
|
||||
: null
|
||||
return JSON.stringify({ error })
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.11.1-beta.1"
|
||||
version = "0.13.0-beta.1"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
@@ -18,7 +18,7 @@ futures.workspace = true
|
||||
lancedb = { path = "../rust/lancedb", features = ["remote"] }
|
||||
napi = { version = "2.16.8", default-features = false, features = [
|
||||
"napi9",
|
||||
"async",
|
||||
"async"
|
||||
] }
|
||||
napi-derive = "2.16.4"
|
||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||
|
||||
@@ -402,6 +402,40 @@ describe("When creating an index", () => {
|
||||
expect(rst.numRows).toBe(1);
|
||||
});
|
||||
|
||||
it("should be able to query unindexed data", async () => {
|
||||
await tbl.createIndex("vec");
|
||||
await tbl.add([
|
||||
{
|
||||
id: 300,
|
||||
vec: Array(32)
|
||||
.fill(1)
|
||||
.map(() => Math.random()),
|
||||
tags: [],
|
||||
},
|
||||
]);
|
||||
|
||||
const plan1 = await tbl.query().nearestTo(queryVec).explainPlan(true);
|
||||
expect(plan1).toMatch("LanceScan");
|
||||
|
||||
const plan2 = await tbl
|
||||
.query()
|
||||
.nearestTo(queryVec)
|
||||
.fastSearch()
|
||||
.explainPlan(true);
|
||||
expect(plan2).not.toMatch("LanceScan");
|
||||
});
|
||||
|
||||
it("should be able to query with row id", async () => {
|
||||
const results = await tbl
|
||||
.query()
|
||||
.nearestTo(queryVec)
|
||||
.withRowId()
|
||||
.limit(1)
|
||||
.toArray();
|
||||
expect(results.length).toBe(1);
|
||||
expect(results[0]).toHaveProperty("_rowid");
|
||||
});
|
||||
|
||||
it("should allow parameters to be specified", async () => {
|
||||
await tbl.createIndex("vec", {
|
||||
config: Index.ivfPq({
|
||||
|
||||
@@ -239,6 +239,29 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Skip searching un-indexed data. This can make search faster, but will miss
|
||||
* any data that is not yet indexed.
|
||||
*
|
||||
* Use {@link lancedb.Table#optimize} to index all un-indexed data.
|
||||
*/
|
||||
fastSearch(): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.fastSearch());
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether to return the row id in the results.
|
||||
*
|
||||
* This column can be used to match results between different queries. For
|
||||
* example, to match results from a full text search and a vector search in
|
||||
* order to perform hybrid search.
|
||||
*/
|
||||
withRowId(): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.withRowId());
|
||||
return this;
|
||||
}
|
||||
|
||||
protected nativeExecute(
|
||||
options?: Partial<QueryExecutionOptions>,
|
||||
): Promise<NativeBatchIterator> {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
3
nodejs/npm/win32-arm64-msvc/README.md
Normal file
3
nodejs/npm/win32-arm64-msvc/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# `@lancedb/lancedb-win32-arm64-msvc`
|
||||
|
||||
This is the **aarch64-pc-windows-msvc** binary for `@lancedb/lancedb`
|
||||
18
nodejs/npm/win32-arm64-msvc/package.json
Normal file
18
nodejs/npm/win32-arm64-msvc/package.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.13.0-beta.1",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"main": "lancedb.win32-arm64-msvc.node",
|
||||
"files": [
|
||||
"lancedb.win32-arm64-msvc.node"
|
||||
],
|
||||
"license": "Apache 2.0",
|
||||
"engines": {
|
||||
"node": ">= 18"
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.12.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.12.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"vector database",
|
||||
"ann"
|
||||
],
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.1",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -82,7 +82,7 @@ pub struct OpenTableOptions {
|
||||
#[napi::module_init]
|
||||
fn init() {
|
||||
let env = Env::new()
|
||||
.filter_or("LANCEDB_LOG", "trace")
|
||||
.filter_or("LANCEDB_LOG", "warn")
|
||||
.write_style("LANCEDB_LOG_STYLE");
|
||||
env_logger::init_from_env(env);
|
||||
}
|
||||
|
||||
@@ -80,6 +80,16 @@ impl Query {
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
@@ -183,6 +193,16 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().offset(offset as usize);
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.15.0"
|
||||
current_version = "0.16.0-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.15.0"
|
||||
version = "0.16.0-beta.0"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -3,13 +3,11 @@ name = "lancedb"
|
||||
# version in Cargo.toml
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.19.1",
|
||||
"requests>=2.31.0",
|
||||
"nest-asyncio~=1.0",
|
||||
"pylance==0.19.2-beta.3",
|
||||
"tqdm>=4.27.0",
|
||||
"pydantic>=1.10",
|
||||
"attrs>=21.3.0",
|
||||
"packaging",
|
||||
"cachetools",
|
||||
"overrides>=0.7",
|
||||
]
|
||||
description = "lancedb"
|
||||
@@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = [
|
||||
"requests>=2.31.0",
|
||||
"openai>=1.6.1",
|
||||
"sentence-transformers",
|
||||
"torch",
|
||||
|
||||
@@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from lancedb.remote import ClientConfig
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .schema import vector
|
||||
from .table import AsyncTable
|
||||
|
||||
@@ -37,6 +35,7 @@ def connect(
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
@@ -64,14 +63,10 @@ def connect(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||
If an integer, then a ThreadPoolExecutor will be created with that
|
||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
client_config: ClientConfig or dict, optional
|
||||
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||
the keys are the attributes of the ClientConfig class. If None, then the
|
||||
default configuration is used.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -94,6 +89,8 @@ def connect(
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
from .remote.db import RemoteDBConnection
|
||||
|
||||
if isinstance(uri, str) and uri.startswith("db://"):
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||
@@ -106,7 +103,9 @@ def connect(
|
||||
api_key,
|
||||
region,
|
||||
host_override,
|
||||
# TODO: remove this (deprecation warning downstream)
|
||||
request_thread_pool=request_thread_pool,
|
||||
client_config=client_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ class Connection(object):
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
) -> Table: ...
|
||||
async def rename_table(self, old_name: str, new_name: str) -> None: ...
|
||||
async def drop_table(self, name: str) -> None: ...
|
||||
|
||||
class Table:
|
||||
def name(self) -> str: ...
|
||||
|
||||
@@ -817,6 +817,18 @@ class AsyncConnection(object):
|
||||
table = await self._inner.open_table(name, storage_options, index_cache_size)
|
||||
return AsyncTable(table)
|
||||
|
||||
async def rename_table(self, old_name: str, new_name: str):
|
||||
"""Rename a table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_name: str
|
||||
The current name of the table.
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
"""
|
||||
await self._inner.rename_table(old_name, new_name)
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
|
||||
@@ -27,3 +27,4 @@ from .imagebind import ImageBindEmbeddings
|
||||
from .utils import with_embeddings
|
||||
from .jinaai import JinaEmbeddings
|
||||
from .watsonx import WatsonxEmbeddings
|
||||
from .voyageai import VoyageAIEmbeddingFunction
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
import os
|
||||
import io
|
||||
import requests
|
||||
import base64
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
@@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
|
||||
def _init_client(self):
|
||||
import requests
|
||||
|
||||
if JinaEmbeddings._session is None:
|
||||
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
|
||||
api_key_not_found_help("jina")
|
||||
|
||||
127
python/python/lancedb/embeddings/voyageai.py
Normal file
127
python/python/lancedb/embeddings/voyageai.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import ClassVar, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help, TEXT
|
||||
|
||||
|
||||
@register("voyageai")
|
||||
class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the VoyageAI API
|
||||
|
||||
https://docs.voyageai.com/docs/embeddings
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the model to use. List of acceptable models:
|
||||
|
||||
* voyage-3
|
||||
* voyage-3-lite
|
||||
* voyage-finance-2
|
||||
* voyage-multilingual-2
|
||||
* voyage-law-2
|
||||
* voyage-code-2
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
voyageai = EmbeddingFunctionRegistry
|
||||
.get_instance()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-3")
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
data = [ { "text": "hello world" },
|
||||
{ "text": "goodbye world" }]
|
||||
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(data)
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
client: ClassVar = None
|
||||
|
||||
def ndims(self):
|
||||
if self.name == "voyage-3-lite":
|
||||
return 512
|
||||
elif self.name == "voyage-code-2":
|
||||
return 1536
|
||||
elif self.name in [
|
||||
"voyage-3",
|
||||
"voyage-finance-2",
|
||||
"voyage-multilingual-2",
|
||||
"voyage-law-2",
|
||||
]:
|
||||
return 1024
|
||||
else:
|
||||
raise ValueError(f"Model {self.name} not supported")
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, input_type="query")
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
input_type = (
|
||||
kwargs.get("input_type") or "document"
|
||||
) # assume source input type if not passed by `compute_query_embeddings`
|
||||
return self.generate_embeddings(texts, input_type=input_type)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
input_type: Optional[str]
|
||||
|
||||
truncation: Optional[bool]
|
||||
"""
|
||||
VoyageAIEmbeddingFunction._init_client()
|
||||
rs = VoyageAIEmbeddingFunction.client.embed(
|
||||
texts=texts, model=self.name, **kwargs
|
||||
)
|
||||
|
||||
return [emb for emb in rs.embeddings]
|
||||
|
||||
@staticmethod
|
||||
def _init_client():
|
||||
if VoyageAIEmbeddingFunction.client is None:
|
||||
voyageai = attempt_import_or_raise("voyageai")
|
||||
if os.environ.get("VOYAGE_API_KEY") is None:
|
||||
api_key_not_found_help("voyageai")
|
||||
VoyageAIEmbeddingFunction.client = voyageai.Client(
|
||||
os.environ["VOYAGE_API_KEY"]
|
||||
)
|
||||
@@ -110,7 +110,16 @@ class FTS:
|
||||
remove_stop_words: bool = False,
|
||||
ascii_folding: bool = False,
|
||||
):
|
||||
self._inner = LanceDbIndex.fts(with_position=with_position)
|
||||
self._inner = LanceDbIndex.fts(
|
||||
with_position=with_position,
|
||||
base_tokenizer=base_tokenizer,
|
||||
language=language,
|
||||
max_token_length=max_token_length,
|
||||
lower_case=lower_case,
|
||||
stem=stem,
|
||||
remove_stop_words=remove_stop_words,
|
||||
ascii_folding=ascii_folding,
|
||||
)
|
||||
|
||||
|
||||
class HnswPq:
|
||||
@@ -467,6 +476,8 @@ class IvfPq:
|
||||
|
||||
The default value is 256.
|
||||
"""
|
||||
if distance_type is not None:
|
||||
distance_type = distance_type.lower()
|
||||
self._inner = LanceDbIndex.ivf_pq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
|
||||
@@ -481,6 +481,7 @@ class LanceQueryBuilder(ABC):
|
||||
>>> plan = table.search(query).explain_plan(True)
|
||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
|
||||
GlobalLimitExec: skip=0, fetch=10
|
||||
FilterExec: _distance@2 IS NOT NULL
|
||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
|
||||
KNNVectorDistance: metric=l2
|
||||
@@ -500,7 +501,16 @@ class LanceQueryBuilder(ABC):
|
||||
nearest={
|
||||
"column": self._vector_column,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._metric,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
},
|
||||
prefilter=self._prefilter,
|
||||
filter=self._str_query,
|
||||
limit=self._limit,
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
).explain_plan(verbose)
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
||||
@@ -1315,6 +1325,48 @@ class AsyncQueryBase(object):
|
||||
self._inner.offset(offset)
|
||||
return self
|
||||
|
||||
def fast_search(self) -> AsyncQuery:
|
||||
"""
|
||||
Skip searching un-indexed data.
|
||||
|
||||
This can make queries faster, but will miss any data that has not been
|
||||
indexed.
|
||||
|
||||
!!! tip
|
||||
You can add new data into an existing index by calling
|
||||
[AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
|
||||
"""
|
||||
self._inner.fast_search()
|
||||
return self
|
||||
|
||||
def with_row_id(self) -> AsyncQuery:
|
||||
"""
|
||||
Include the _rowid column in the results.
|
||||
"""
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def postfilter(self) -> AsyncQuery:
|
||||
"""
|
||||
If this is called then filtering will happen after the search instead of
|
||||
before.
|
||||
By default filtering will be performed before the search. This is how
|
||||
filtering is typically understood to work. This prefilter step does add some
|
||||
additional latency. Creating a scalar index on the filter column(s) can
|
||||
often improve this latency. However, sometimes a filter is too complex or
|
||||
scalar indices cannot be applied to the column. In these cases postfiltering
|
||||
can be used instead of prefiltering to improve latency.
|
||||
Post filtering applies the filter to the results of the search. This
|
||||
means we only run the filter on a much smaller set of data. However, it can
|
||||
cause the query to return fewer than `limit` results (or even no results) if
|
||||
none of the nearest results match the filter.
|
||||
Post filtering happens during the "refine stage" (described in more detail in
|
||||
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
|
||||
factor can often help restore some of the results lost by post filtering.
|
||||
"""
|
||||
self._inner.postfilter()
|
||||
return self
|
||||
|
||||
async def to_batches(
|
||||
self, *, max_batch_length: Optional[int] = None
|
||||
) -> AsyncRecordBatchReader:
|
||||
@@ -1618,30 +1670,6 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.distance_type(distance_type)
|
||||
return self
|
||||
|
||||
def postfilter(self) -> AsyncVectorQuery:
|
||||
"""
|
||||
If this is called then filtering will happen after the vector search instead of
|
||||
before.
|
||||
|
||||
By default filtering will be performed before the vector search. This is how
|
||||
filtering is typically understood to work. This prefilter step does add some
|
||||
additional latency. Creating a scalar index on the filter column(s) can
|
||||
often improve this latency. However, sometimes a filter is too complex or
|
||||
scalar indices cannot be applied to the column. In these cases postfiltering
|
||||
can be used instead of prefiltering to improve latency.
|
||||
|
||||
Post filtering applies the filter to the results of the vector search. This
|
||||
means we only run the filter on a much smaller set of data. However, it can
|
||||
cause the query to return fewer than `limit` results (or even no results) if
|
||||
none of the nearest results match the filter.
|
||||
|
||||
Post filtering happens during the "refine stage" (described in more detail in
|
||||
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
|
||||
factor can often help restore some of the results lost by post filtering.
|
||||
"""
|
||||
self._inner.postfilter()
|
||||
return self
|
||||
|
||||
def bypass_vector_index(self) -> AsyncVectorQuery:
|
||||
"""
|
||||
If this is called then any vector index is skipped
|
||||
|
||||
@@ -11,62 +11,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
import attrs
|
||||
from lancedb import __version__
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lancedb.common import VECTOR_COLUMN_NAME
|
||||
|
||||
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
|
||||
|
||||
|
||||
class VectorQuery(BaseModel):
|
||||
# vector to search for
|
||||
vector: List[float]
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# top k results to return
|
||||
k: int
|
||||
|
||||
# # metrics
|
||||
_metric: str = "L2"
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[List[str]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
vector_column: str = VECTOR_COLUMN_NAME
|
||||
|
||||
fast_search: bool = False
|
||||
|
||||
|
||||
@attrs.define
|
||||
class VectorQueryResult:
|
||||
# for now the response is directly seralized into a pandas dataframe
|
||||
tbl: pa.Table
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
return self.tbl
|
||||
|
||||
|
||||
class LanceDBClient(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query the LanceDB server for the given table and query."""
|
||||
pass
|
||||
__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -165,8 +116,8 @@ class RetryConfig:
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||
retry_config: Optional[RetryConfig] = None
|
||||
timeout_config: Optional[TimeoutConfig] = None
|
||||
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.retry_config, dict):
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Iterable, Union
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes:
|
||||
"""Serialize a PyArrow Table to IPC binary."""
|
||||
sink = pa.BufferOutputStream()
|
||||
if isinstance(table, Iterable):
|
||||
table = pa.Table.from_batches(table)
|
||||
with pa.ipc.new_stream(sink, table.schema) as writer:
|
||||
writer.write_table(table)
|
||||
return sink.getvalue().to_pybytes()
|
||||
@@ -1,269 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import attrs
|
||||
import pyarrow as pa
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3 import Retry
|
||||
|
||||
from lancedb.common import Credential
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
|
||||
from lancedb.remote.errors import LanceDBClientError
|
||||
|
||||
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
|
||||
|
||||
|
||||
def _check_not_closed(f):
|
||||
@functools.wraps(f)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self.closed:
|
||||
raise ValueError("Connection is closed")
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _read_ipc(resp: requests.Response) -> pa.Table:
|
||||
resp_body = resp.content
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
return reader.read_all()
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RestfulLanceDBClient:
|
||||
db_name: str
|
||||
region: str
|
||||
api_key: Credential
|
||||
host_override: Optional[str] = attrs.field(default=None)
|
||||
|
||||
closed: bool = attrs.field(default=False, init=False)
|
||||
|
||||
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
|
||||
read_timeout: float = attrs.field(default=300.0, kw_only=True)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> requests.Session:
|
||||
sess = requests.Session()
|
||||
|
||||
retry_adapter_instance = retry_adapter(retry_adapter_options())
|
||||
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
|
||||
|
||||
adapter_class = LanceDBClientHTTPAdapterFactory()
|
||||
sess.mount("https://", adapter_class())
|
||||
return sess
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return (
|
||||
self.host_override
|
||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.close()
|
||||
return False # Do not suppress exceptions
|
||||
|
||||
def close(self):
|
||||
self.session.close()
|
||||
self.closed = True
|
||||
|
||||
@functools.cached_property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
if self.region == "local": # Local test mode
|
||||
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
||||
if self.host_override:
|
||||
headers["x-lancedb-database"] = self.db_name
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _check_status(resp: requests.Response):
|
||||
# Leaving request id empty for now, as we'll be replacing this impl
|
||||
# with the Rust one shortly.
|
||||
if resp.status_code == 404:
|
||||
raise LanceDBClientError(
|
||||
f"Not found: {resp.text}", request_id="", status_code=404
|
||||
)
|
||||
elif 400 <= resp.status_code < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status_code}, error: {resp.text}",
|
||||
request_id="",
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
elif 500 <= resp.status_code < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
|
||||
request_id="",
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
elif resp.status_code != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status_code}, error: {resp.text}",
|
||||
request_id="",
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
@_check_not_closed
|
||||
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
"""Send a GET request and returns the deserialized response payload."""
|
||||
if isinstance(params, BaseModel):
|
||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
||||
with self.session.get(
|
||||
urljoin(self.url, uri),
|
||||
params=params,
|
||||
headers=self.headers,
|
||||
timeout=(self.connection_timeout, self.read_timeout),
|
||||
) as resp:
|
||||
self._check_status(resp)
|
||||
return resp.json()
|
||||
|
||||
@_check_not_closed
|
||||
def post(
|
||||
self,
|
||||
uri: str,
|
||||
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
content_type: Optional[str] = None,
|
||||
deserialize: Callable = lambda resp: resp.json(),
|
||||
request_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a POST request and returns the deserialized response payload.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
The uri to send the POST request to.
|
||||
data: Union[Dict[str, Any], BaseModel]
|
||||
request_id: Optional[str]
|
||||
Optional client side request id to be sent in the request headers.
|
||||
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data: Dict[str, Any] = data.dict(exclude_none=True)
|
||||
if isinstance(data, bytes):
|
||||
req_kwargs = {"data": data}
|
||||
else:
|
||||
req_kwargs = {"json": data}
|
||||
|
||||
headers = self.headers.copy()
|
||||
if content_type is not None:
|
||||
headers["content-type"] = content_type
|
||||
if request_id is not None:
|
||||
headers["x-request-id"] = request_id
|
||||
with self.session.post(
|
||||
urljoin(self.url, uri),
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=(self.connection_timeout, self.read_timeout),
|
||||
**req_kwargs,
|
||||
) as resp:
|
||||
self._check_status(resp)
|
||||
return deserialize(resp)
|
||||
|
||||
@_check_not_closed
|
||||
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
|
||||
"""List all tables in the database."""
|
||||
if page_token is None:
|
||||
page_token = ""
|
||||
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query a table."""
|
||||
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
def mount_retry_adapter_for_table(self, table_name: str) -> None:
|
||||
"""
|
||||
Adds an http adapter to session that will retry retryable requests to the table.
|
||||
"""
|
||||
retry_options = retry_adapter_options(methods=["GET", "POST"])
|
||||
retry_adapter_instance = retry_adapter(retry_options)
|
||||
session = self.session
|
||||
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
|
||||
)
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
|
||||
retry_adapter_instance,
|
||||
)
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
|
||||
retry_adapter_instance,
|
||||
)
|
||||
|
||||
|
||||
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
|
||||
return {
|
||||
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
|
||||
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
|
||||
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
|
||||
"backoff_factor": float(
|
||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
|
||||
),
|
||||
"backoff_jitter": float(
|
||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
|
||||
),
|
||||
"statuses": [
|
||||
int(i.strip())
|
||||
for i in os.environ.get(
|
||||
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
|
||||
).split(",")
|
||||
],
|
||||
"methods": methods,
|
||||
}
|
||||
|
||||
|
||||
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
|
||||
total_retries = options["retries"]
|
||||
connect_retries = options["connect_retries"]
|
||||
read_retries = options["read_retries"]
|
||||
backoff_factor = options["backoff_factor"]
|
||||
backoff_jitter = options["backoff_jitter"]
|
||||
statuses = options["statuses"]
|
||||
methods = frozenset(options["methods"])
|
||||
logging.debug(
|
||||
f"Setting up retry adapter with {total_retries} retries," # noqa G003
|
||||
+ f"connect retries {connect_retries}, read retries {read_retries},"
|
||||
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
|
||||
+ f"methods {methods}"
|
||||
)
|
||||
|
||||
return HTTPAdapter(
|
||||
max_retries=Retry(
|
||||
total=total_retries,
|
||||
connect=connect_retries,
|
||||
read=read_retries,
|
||||
backoff_factor=backoff_factor,
|
||||
backoff_jitter=backoff_jitter,
|
||||
status_forcelist=statuses,
|
||||
allowed_methods=methods,
|
||||
)
|
||||
)
|
||||
@@ -1,115 +0,0 @@
|
||||
# Copyright 2024 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This module contains an adapter that will close connections if they have not been
|
||||
# used before a certain timeout. This is necessary because some load balancers will
|
||||
# close connections after a certain amount of time, but the request module may not yet
|
||||
# have received the FIN/ACK and will try to reuse the connection.
|
||||
#
|
||||
# TODO some of the code here can be simplified if/when this PR is merged:
|
||||
# https://github.com/urllib3/urllib3/pull/3275
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.connection import HTTPSConnection
|
||||
from urllib3.connectionpool import HTTPSConnectionPool
|
||||
from urllib3.poolmanager import PoolManager
|
||||
|
||||
|
||||
def get_client_connection_timeout() -> int:
|
||||
return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300"))
|
||||
|
||||
|
||||
class LanceDBHTTPSConnection(HTTPSConnection):
|
||||
"""
|
||||
HTTPSConnection that tracks the last time it was used.
|
||||
"""
|
||||
|
||||
idle_timeout: datetime.timedelta
|
||||
last_activity: datetime.datetime
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.last_activity = datetime.datetime.now()
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
self.last_activity = datetime.datetime.now()
|
||||
super().request(*args, **kwargs)
|
||||
|
||||
def is_expired(self):
|
||||
return datetime.datetime.now() - self.last_activity > self.idle_timeout
|
||||
|
||||
|
||||
def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int):
|
||||
"""
|
||||
Creates a connection pool class that can be used to close idle connections.
|
||||
"""
|
||||
|
||||
class LanceDBHTTPSConnectionPool(HTTPSConnectionPool):
|
||||
# override the connection class
|
||||
ConnectionCls = LanceDBHTTPSConnection
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _get_conn(self, timeout: float | None = None):
|
||||
logging.debug("Getting https connection")
|
||||
conn = super()._get_conn(timeout)
|
||||
if conn.is_expired():
|
||||
logging.debug("Closing expired connection")
|
||||
conn.close()
|
||||
|
||||
return conn
|
||||
|
||||
def _new_conn(self):
|
||||
conn = super()._new_conn()
|
||||
conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout)
|
||||
return conn
|
||||
|
||||
return LanceDBHTTPSConnectionPool
|
||||
|
||||
|
||||
class LanceDBClientPoolManager(PoolManager):
|
||||
def __init__(
|
||||
self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs
|
||||
):
|
||||
super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs)
|
||||
# inject our connection pool impl
|
||||
connection_pool_class = LanceDBHTTPSConnectionPoolFactory(
|
||||
client_idle_timeout=client_idle_timeout
|
||||
)
|
||||
self.pool_classes_by_scheme["https"] = connection_pool_class
|
||||
|
||||
|
||||
def LanceDBClientHTTPAdapterFactory():
|
||||
"""
|
||||
Creates an HTTPAdapter class that can be used to close idle connections
|
||||
"""
|
||||
|
||||
# closure over the timeout
|
||||
client_idle_timeout = get_client_connection_timeout()
|
||||
|
||||
class LanceDBClientRequestHTTPAdapter(HTTPAdapter):
|
||||
def init_poolmanager(self, connections, maxsize, block=False):
|
||||
# inject our pool manager impl
|
||||
self.poolmanager = LanceDBClientPoolManager(
|
||||
client_idle_timeout=client_idle_timeout,
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
)
|
||||
|
||||
return LanceDBClientRequestHTTPAdapter
|
||||
@@ -11,13 +11,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
import warnings
|
||||
|
||||
from cachetools import TTLCache
|
||||
from lancedb import connect_async
|
||||
from lancedb.remote import ClientConfig
|
||||
import pyarrow as pa
|
||||
from overrides import override
|
||||
|
||||
@@ -25,10 +28,8 @@ from ..common import DATA
|
||||
from ..db import DBConnection
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from ..pydantic import LanceModel
|
||||
from ..table import Table, sanitize_create_table
|
||||
from ..table import Table
|
||||
from ..util import validate_table_name
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
|
||||
|
||||
class RemoteDBConnection(DBConnection):
|
||||
@@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection):
|
||||
region: str,
|
||||
host_override: Optional[str] = None,
|
||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||
connection_timeout: float = 120.0,
|
||||
read_timeout: float = 300.0,
|
||||
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
|
||||
connection_timeout: Optional[float] = None,
|
||||
read_timeout: Optional[float] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
|
||||
if isinstance(client_config, dict):
|
||||
client_config = ClientConfig(**client_config)
|
||||
elif client_config is None:
|
||||
client_config = ClientConfig()
|
||||
|
||||
# These are legacy options from the old Python-based client. We keep them
|
||||
# here for backwards compatibility, but will remove them in a future release.
|
||||
if request_thread_pool is not None:
|
||||
warnings.warn(
|
||||
"request_thread_pool is no longer used and will be removed in "
|
||||
"a future release.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if connection_timeout is not None:
|
||||
warnings.warn(
|
||||
"connection_timeout is deprecated and will be removed in a future "
|
||||
"release. Please use client_config.timeout_config.connect_timeout "
|
||||
"instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
client_config.timeout_config.connect_timeout = timedelta(
|
||||
seconds=connection_timeout
|
||||
)
|
||||
|
||||
if read_timeout is not None:
|
||||
warnings.warn(
|
||||
"read_timeout is deprecated and will be removed in a future release. "
|
||||
"Please use client_config.timeout_config.read_timeout instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
|
||||
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self._uri = str(db_url)
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name,
|
||||
region,
|
||||
api_key,
|
||||
host_override,
|
||||
connection_timeout=connection_timeout,
|
||||
read_timeout=read_timeout,
|
||||
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
|
||||
self.client_config = client_config
|
||||
|
||||
self._conn = self._loop.run_until_complete(
|
||||
connect_async(
|
||||
db_url,
|
||||
api_key=api_key,
|
||||
region=region,
|
||||
host_override=host_override,
|
||||
client_config=client_config,
|
||||
)
|
||||
)
|
||||
self._request_thread_pool = request_thread_pool
|
||||
self._table_cache = TTLCache(maxsize=10000, ttl=300)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
@@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection):
|
||||
-------
|
||||
An iterator of table names.
|
||||
"""
|
||||
while True:
|
||||
result = self._client.list_tables(limit, page_token)
|
||||
|
||||
if len(result) > 0:
|
||||
page_token = result[len(result) - 1]
|
||||
else:
|
||||
break
|
||||
for item in result:
|
||||
self._table_cache[item] = True
|
||||
yield item
|
||||
return self._loop.run_until_complete(
|
||||
self._conn.table_names(start_after=page_token, limit=limit)
|
||||
)
|
||||
|
||||
@override
|
||||
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
|
||||
@@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection):
|
||||
"""
|
||||
from .table import RemoteTable
|
||||
|
||||
self._client.mount_retry_adapter_for_table(name)
|
||||
|
||||
if index_cache_size is not None:
|
||||
logging.info(
|
||||
"index_cache_size is ignored in LanceDb Cloud"
|
||||
" (there is no local cache to configure)"
|
||||
)
|
||||
|
||||
# check if table exists
|
||||
if self._table_cache.get(name) is None:
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
self._table_cache[name] = True
|
||||
|
||||
return RemoteTable(self, name)
|
||||
table = self._loop.run_until_complete(self._conn.open_table(name))
|
||||
return RemoteTable(table, self.db_name, self._loop)
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
@@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection):
|
||||
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||
"for this feature."
|
||||
)
|
||||
if mode is not None:
|
||||
logging.warning("mode is not yet supported on LanceDB Cloud.")
|
||||
|
||||
data, schema = sanitize_create_table(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
|
||||
from .table import RemoteTable
|
||||
|
||||
data = to_ipc_binary(data)
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/create/",
|
||||
data=data,
|
||||
request_id=request_id,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
table = self._loop.run_until_complete(
|
||||
self._conn.create_table(
|
||||
name,
|
||||
data,
|
||||
mode=mode,
|
||||
schema=schema,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
)
|
||||
|
||||
self._table_cache[name] = True
|
||||
return RemoteTable(self, name)
|
||||
return RemoteTable(table, self.db_name, self._loop)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str):
|
||||
@@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection):
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/drop/",
|
||||
)
|
||||
self._table_cache.pop(name, default=None)
|
||||
self._loop.run_until_complete(self._conn.drop_table(name))
|
||||
|
||||
@override
|
||||
def rename_table(self, cur_name: str, new_name: str):
|
||||
@@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection):
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
"""
|
||||
self._client.post(
|
||||
f"/v1/table/{cur_name}/rename/",
|
||||
data={"new_table_name": new_name},
|
||||
)
|
||||
self._table_cache.pop(cur_name, default=None)
|
||||
self._table_cache[new_name] = True
|
||||
self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name))
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
|
||||
@@ -11,53 +11,57 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from datetime import timedelta
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||
|
||||
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
|
||||
import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from lancedb.merge import LanceMergeInsertBuilder
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
from ..util import value_to_sql, infer_vector_column_name
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||
from .db import RemoteDBConnection
|
||||
from ..table import AsyncTable, Query, Table
|
||||
|
||||
|
||||
class RemoteTable(Table):
|
||||
def __init__(self, conn: RemoteDBConnection, name: str):
|
||||
self._conn = conn
|
||||
self.name = name
|
||||
def __init__(
|
||||
self,
|
||||
table: AsyncTable,
|
||||
db_name: str,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
):
|
||||
self._loop = loop
|
||||
self._table = table
|
||||
self.db_name = db_name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the table"""
|
||||
return self._table.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self._conn.db_name}.{self.name})"
|
||||
return f"RemoteTable({self.db_name}.{self.name})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.count_rows(None)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
of this Table
|
||||
|
||||
"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
schema = json_to_schema(resp["schema"])
|
||||
return schema
|
||||
return self._loop.run_until_complete(self._table.schema())
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
return resp["version"]
|
||||
return self._loop.run_until_complete(self._table.version())
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
@@ -84,20 +88,18 @@ class RemoteTable(Table):
|
||||
|
||||
def list_indices(self):
|
||||
"""List all the indices on the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
|
||||
return resp
|
||||
return self._loop.run_until_complete(self._table.list_indices())
|
||||
|
||||
def index_stats(self, index_uuid: str):
|
||||
"""List all the stats of a specified index"""
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
|
||||
)
|
||||
return resp
|
||||
return self._loop.run_until_complete(self._table.index_stats(index_uuid))
|
||||
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Creates a scalar index
|
||||
Parameters
|
||||
@@ -107,36 +109,51 @@ class RemoteTable(Table):
|
||||
or string column.
|
||||
index_type : str
|
||||
The index type of the scalar index. Must be "scalar" (BTREE),
|
||||
"BTREE", "BITMAP", or "LABEL_LIST"
|
||||
"BTREE", "BITMAP", or "LABEL_LIST",
|
||||
replace : bool
|
||||
If True, replace the existing index with the new one.
|
||||
"""
|
||||
if index_type == "scalar" or index_type == "BTREE":
|
||||
config = BTree()
|
||||
elif index_type == "BITMAP":
|
||||
config = Bitmap()
|
||||
elif index_type == "LABEL_LIST":
|
||||
config = LabelList()
|
||||
else:
|
||||
raise ValueError(f"Unknown index type: {index_type}")
|
||||
|
||||
data = {
|
||||
"column": column,
|
||||
"index_type": index_type,
|
||||
"replace": True,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_scalar_index/", data=data
|
||||
self._loop.run_until_complete(
|
||||
self._table.create_index(column, config=config, replace=replace)
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
def create_fts_index(
|
||||
self,
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = False,
|
||||
with_position: bool = True,
|
||||
# tokenizer configs:
|
||||
base_tokenizer: str = "simple",
|
||||
language: str = "English",
|
||||
max_token_length: Optional[int] = 40,
|
||||
lower_case: bool = True,
|
||||
stem: bool = False,
|
||||
remove_stop_words: bool = False,
|
||||
ascii_folding: bool = False,
|
||||
):
|
||||
data = {
|
||||
"column": column,
|
||||
"index_type": "FTS",
|
||||
"replace": replace,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
config = FTS(
|
||||
with_position=with_position,
|
||||
base_tokenizer=base_tokenizer,
|
||||
language=language,
|
||||
max_token_length=max_token_length,
|
||||
lower_case=lower_case,
|
||||
stem=stem,
|
||||
remove_stop_words=remove_stop_words,
|
||||
ascii_folding=ascii_folding,
|
||||
)
|
||||
self._loop.run_until_complete(
|
||||
self._table.create_index(column, config=config, replace=replace)
|
||||
)
|
||||
return resp
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
@@ -204,17 +221,22 @@ class RemoteTable(Table):
|
||||
"Existing indexes will always be replaced."
|
||||
)
|
||||
|
||||
data = {
|
||||
"column": vector_column_name,
|
||||
"index_type": index_type,
|
||||
"metric_type": metric,
|
||||
"index_cache_size": index_cache_size,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
)
|
||||
index_type = index_type.upper()
|
||||
if index_type == "VECTOR" or index_type == "IVF_PQ":
|
||||
config = IvfPq(distance_type=metric)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
config = HnswPq(distance_type=metric)
|
||||
elif index_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(distance_type=metric)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {index_type}. Valid options are"
|
||||
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||
)
|
||||
|
||||
return resp
|
||||
self._loop.run_until_complete(
|
||||
self._table.create_index(vector_column_name, config=config)
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
@@ -246,22 +268,10 @@ class RemoteTable(Table):
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
"""
|
||||
data, _ = _sanitize_data(
|
||||
data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self.name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
self._loop.run_until_complete(
|
||||
self._table.add(
|
||||
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
)
|
||||
|
||||
def search(
|
||||
@@ -337,12 +347,6 @@ class RemoteTable(Table):
|
||||
# empty query builder is not supported in saas, raise error
|
||||
if query is None and query_type != "hybrid":
|
||||
raise ValueError("Empty query is not supported")
|
||||
vector_column_name = infer_vector_column_name(
|
||||
schema=self.schema,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
vector_column_name=vector_column_name,
|
||||
)
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
@@ -356,37 +360,9 @@ class RemoteTable(Table):
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
if (
|
||||
query.vector is not None
|
||||
and len(query.vector) > 0
|
||||
and not isinstance(query.vector[0], float)
|
||||
):
|
||||
if self._conn._request_thread_pool is None:
|
||||
|
||||
def submit(name, q):
|
||||
f = Future()
|
||||
f.set_result(self._conn._client.query(name, q))
|
||||
return f
|
||||
|
||||
else:
|
||||
|
||||
def submit(name, q):
|
||||
return self._conn._request_thread_pool.submit(
|
||||
self._conn._client.query, name, q
|
||||
)
|
||||
|
||||
results = []
|
||||
for v in query.vector:
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
results.append(submit(self.name, q))
|
||||
return pa.concat_tables(
|
||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
||||
).to_reader()
|
||||
else:
|
||||
result = self._conn._client.query(self.name, query)
|
||||
return result.to_arrow().to_reader()
|
||||
return self._loop.run_until_complete(
|
||||
self._table._execute_query(query, batch_size=batch_size)
|
||||
)
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||
@@ -403,42 +379,8 @@ class RemoteTable(Table):
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
data, _ = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
params = {}
|
||||
if len(merge._on) != 1:
|
||||
raise ValueError(
|
||||
"RemoteTable only supports a single on key in merge_insert"
|
||||
)
|
||||
params["on"] = merge._on[0]
|
||||
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
|
||||
if merge._when_matched_update_all_condition is not None:
|
||||
params["when_matched_update_all_filt"] = (
|
||||
merge._when_matched_update_all_condition
|
||||
)
|
||||
params["when_not_matched_insert_all"] = str(
|
||||
merge._when_not_matched_insert_all
|
||||
).lower()
|
||||
params["when_not_matched_by_source_delete"] = str(
|
||||
merge._when_not_matched_by_source_delete
|
||||
).lower()
|
||||
if merge._when_not_matched_by_source_condition is not None:
|
||||
params["when_not_matched_by_source_delete_filt"] = (
|
||||
merge._when_not_matched_by_source_condition
|
||||
)
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self.name}/merge_insert/",
|
||||
data=payload,
|
||||
params=params,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
self._loop.run_until_complete(
|
||||
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
|
||||
)
|
||||
|
||||
def delete(self, predicate: str):
|
||||
@@ -488,8 +430,7 @@ class RemoteTable(Table):
|
||||
x vector _distance # doctest: +SKIP
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
payload = {"predicate": predicate}
|
||||
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
|
||||
self._loop.run_until_complete(self._table.delete(predicate))
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -539,18 +480,9 @@ class RemoteTable(Table):
|
||||
2 2 [10.0, 10.0] # doctest: +SKIP
|
||||
|
||||
"""
|
||||
if values is not None and values_sql is not None:
|
||||
raise ValueError("Only one of values or values_sql can be provided")
|
||||
if values is None and values_sql is None:
|
||||
raise ValueError("Either values or values_sql must be provided")
|
||||
|
||||
if values is not None:
|
||||
updates = [[k, value_to_sql(v)] for k, v in values.items()]
|
||||
else:
|
||||
updates = [[k, v] for k, v in values_sql.items()]
|
||||
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
|
||||
self._loop.run_until_complete(
|
||||
self._table.update(where=where, updates=values, updates_sql=values_sql)
|
||||
)
|
||||
|
||||
def cleanup_old_versions(self, *_):
|
||||
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||
@@ -564,12 +496,21 @@ class RemoteTable(Table):
|
||||
"compact_files() is not supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
payload = {"predicate": filter}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/count_rows/", data=payload
|
||||
def optimize(
|
||||
self,
|
||||
*,
|
||||
cleanup_older_than: Optional[timedelta] = None,
|
||||
delete_unverified: bool = False,
|
||||
):
|
||||
"""optimize() is not supported on the LanceDB cloud.
|
||||
Indices are optimized automatically."""
|
||||
raise NotImplementedError(
|
||||
"optimize() is not supported on the LanceDB cloud. "
|
||||
"Indices are optimized automatically."
|
||||
)
|
||||
return resp
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
return self._loop.run_until_complete(self._table.count_rows(filter))
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -7,6 +7,7 @@ from .openai import OpenaiReranker
|
||||
from .jinaai import JinaReranker
|
||||
from .rrf import RRFReranker
|
||||
from .answerdotai import AnswerdotaiRerankers
|
||||
from .voyageai import VoyageAIReranker
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
@@ -18,4 +19,5 @@ __all__ = [
|
||||
"JinaReranker",
|
||||
"RRFReranker",
|
||||
"AnswerdotaiRerankers",
|
||||
"VoyageAIReranker",
|
||||
]
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import requests
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
@@ -57,6 +56,8 @@ class JinaReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
import requests
|
||||
|
||||
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"JINA_API_KEY not set. Either set it in your environment or \
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from numpy import NaN
|
||||
from numpy import nan
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
@@ -71,7 +71,7 @@ class LinearCombinationReranker(Reranker):
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_distance",
|
||||
pa.array([NaN] * len(fts_results), type=pa.float32()),
|
||||
pa.array([nan] * len(fts_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
@@ -92,7 +92,7 @@ class LinearCombinationReranker(Reranker):
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_score",
|
||||
pa.array([NaN] * len(vector_results), type=pa.float32()),
|
||||
pa.array([nan] * len(vector_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
133
python/python/lancedb/rerankers/voyageai.py
Normal file
133
python/python/lancedb/rerankers/voyageai.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Union, Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class VoyageAIReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using the VoyageAI Rerank API.
|
||||
https://docs.voyageai.com/docs/reranker
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "rerank-english-v2.0"
|
||||
The name of the cross encoder model to use. Available voyageai models are:
|
||||
- rerank-2
|
||||
- rerank-2-lite
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
top_n : int, default None
|
||||
The number of results to return. If None, will return all results.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
api_key : str, default None
|
||||
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||
truncation : Optional[bool], default None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
column: str = "text",
|
||||
top_n: Optional[int] = None,
|
||||
return_score="relevance",
|
||||
api_key: Optional[str] = None,
|
||||
truncation: Optional[bool] = True,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.top_n = top_n
|
||||
self.api_key = api_key
|
||||
self.truncation = truncation
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
voyageai = attempt_import_or_raise("voyageai")
|
||||
if os.environ.get("VOYAGE_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"VOYAGE_API_KEY not set. Either set it in your environment or \
|
||||
pass it as `api_key` argument to the VoyageAIReranker."
|
||||
)
|
||||
return voyageai.Client(
|
||||
api_key=os.environ.get("VOYAGE_API_KEY") or self.api_key,
|
||||
)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
top_k=self.top_n,
|
||||
model=self.model_name,
|
||||
truncation=self.truncation,
|
||||
)
|
||||
results = (
|
||||
response.results
|
||||
) # returns list (text, idx, relevance) attributes sorted descending by score
|
||||
indices, scores = list(
|
||||
zip(*[(result.index, result.relevance_score) for result in results])
|
||||
) # tuples
|
||||
result_set = result_set.take(list(indices))
|
||||
# add the scores
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for voyageai reranker"
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_score"])
|
||||
|
||||
return result_set
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -32,7 +33,7 @@ import pyarrow.fs as pa_fs
|
||||
from lance import LanceDataset
|
||||
from lance.dependencies import _check_for_hugging_face
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME, sanitize_uri
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .merge import LanceMergeInsertBuilder
|
||||
from .pydantic import LanceModel, model_to_dict
|
||||
@@ -57,12 +58,14 @@ from .util import (
|
||||
)
|
||||
from .index import lang_mapping
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
from ._lancedb import Table as LanceDBTable, OptimizeStats
|
||||
from .db import LanceDBConnection
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
@@ -893,6 +896,55 @@ class Table(ABC):
|
||||
For most cases, the default should be fine.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def optimize(
|
||||
self,
|
||||
*,
|
||||
cleanup_older_than: Optional[timedelta] = None,
|
||||
delete_unverified: bool = False,
|
||||
):
|
||||
"""
|
||||
Optimize the on-disk data and indices for better performance.
|
||||
|
||||
Modeled after ``VACUUM`` in PostgreSQL.
|
||||
|
||||
Optimization covers three operations:
|
||||
|
||||
* Compaction: Merges small files into larger ones
|
||||
* Prune: Removes old versions of the dataset
|
||||
* Index: Optimizes the indices, adding new data to existing indices
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cleanup_older_than: timedelta, optional default 7 days
|
||||
All files belonging to versions older than this will be removed. Set
|
||||
to 0 days to remove all versions except the latest. The latest version
|
||||
is never removed.
|
||||
delete_unverified: bool, default False
|
||||
Files leftover from a failed transaction may appear to be part of an
|
||||
in-progress operation (e.g. appending new data) and these files will not
|
||||
be deleted unless they are at least 7 days old. If delete_unverified is True
|
||||
then these files will be deleted regardless of their age.
|
||||
|
||||
Experimental API
|
||||
----------------
|
||||
|
||||
The optimization process is undergoing active development and may change.
|
||||
Our goal with these changes is to improve the performance of optimization and
|
||||
reduce the complexity.
|
||||
|
||||
That being said, it is essential today to run optimize if you want the best
|
||||
performance. It should be stable and safe to use in production, but it our
|
||||
hope that the API may be simplified (or not even need to be called) in the
|
||||
future.
|
||||
|
||||
The frequency an application shoudl call optimize is based on the frequency of
|
||||
data modifications. If data is frequently added, deleted, or updated then
|
||||
optimize should be run frequently. A good rule of thumb is to run optimize if
|
||||
you have added or modified 100,000 or more records or run more than 20 data
|
||||
modification operations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
"""
|
||||
@@ -948,7 +1000,9 @@ class Table(ABC):
|
||||
return _table_uri(self._conn.uri, self.name)
|
||||
|
||||
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
|
||||
if get_uri_scheme(self._dataset_uri) != "file":
|
||||
from .remote.table import RemoteTable
|
||||
|
||||
if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file":
|
||||
return ("", None, False)
|
||||
path = join_uri(self._dataset_uri, "_indices", "fts")
|
||||
fs, path = fs_from_uri(path)
|
||||
@@ -1969,6 +2023,83 @@ class LanceTable(Table):
|
||||
"""
|
||||
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
||||
|
||||
def optimize(
|
||||
self,
|
||||
*,
|
||||
cleanup_older_than: Optional[timedelta] = None,
|
||||
delete_unverified: bool = False,
|
||||
):
|
||||
"""
|
||||
Optimize the on-disk data and indices for better performance.
|
||||
|
||||
Modeled after ``VACUUM`` in PostgreSQL.
|
||||
|
||||
Optimization covers three operations:
|
||||
|
||||
* Compaction: Merges small files into larger ones
|
||||
* Prune: Removes old versions of the dataset
|
||||
* Index: Optimizes the indices, adding new data to existing indices
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cleanup_older_than: timedelta, optional default 7 days
|
||||
All files belonging to versions older than this will be removed. Set
|
||||
to 0 days to remove all versions except the latest. The latest version
|
||||
is never removed.
|
||||
delete_unverified: bool, default False
|
||||
Files leftover from a failed transaction may appear to be part of an
|
||||
in-progress operation (e.g. appending new data) and these files will not
|
||||
be deleted unless they are at least 7 days old. If delete_unverified is True
|
||||
then these files will be deleted regardless of their age.
|
||||
|
||||
Experimental API
|
||||
----------------
|
||||
|
||||
The optimization process is undergoing active development and may change.
|
||||
Our goal with these changes is to improve the performance of optimization and
|
||||
reduce the complexity.
|
||||
|
||||
That being said, it is essential today to run optimize if you want the best
|
||||
performance. It should be stable and safe to use in production, but it our
|
||||
hope that the API may be simplified (or not even need to be called) in the
|
||||
future.
|
||||
|
||||
The frequency an application shoudl call optimize is based on the frequency of
|
||||
data modifications. If data is frequently added, deleted, or updated then
|
||||
optimize should be run frequently. A good rule of thumb is to run optimize if
|
||||
you have added or modified 100,000 or more records or run more than 20 data
|
||||
modification operations.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
raise AssertionError(
|
||||
"Synchronous method called in asynchronous context. "
|
||||
"If you are writing an asynchronous application "
|
||||
"then please use the asynchronous APIs"
|
||||
)
|
||||
|
||||
except RuntimeError:
|
||||
asyncio.run(
|
||||
self._async_optimize(
|
||||
cleanup_older_than=cleanup_older_than,
|
||||
delete_unverified=delete_unverified,
|
||||
)
|
||||
)
|
||||
self.checkout_latest()
|
||||
|
||||
async def _async_optimize(
|
||||
self,
|
||||
cleanup_older_than: Optional[timedelta] = None,
|
||||
delete_unverified: bool = False,
|
||||
):
|
||||
conn = await lancedb_connect(
|
||||
sanitize_uri(self._conn.uri),
|
||||
)
|
||||
table = AsyncTable(await conn.open_table(self.name))
|
||||
return await table.optimize(
|
||||
cleanup_older_than=cleanup_older_than, delete_unverified=delete_unverified
|
||||
)
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
self._dataset_mut.add_columns(transforms)
|
||||
|
||||
@@ -2382,7 +2513,9 @@ class AsyncTable:
|
||||
column: str,
|
||||
*,
|
||||
replace: Optional[bool] = None,
|
||||
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
|
||||
config: Optional[
|
||||
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
||||
] = None,
|
||||
):
|
||||
"""Create an index to speed up queries
|
||||
|
||||
@@ -2535,7 +2668,44 @@ class AsyncTable:
|
||||
async def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
pass
|
||||
# The sync remote table calls into this method, so we need to map the
|
||||
# query to the async version of the query and run that here. This is only
|
||||
# used for that code path right now.
|
||||
async_query = self.query().limit(query.k)
|
||||
if query.offset > 0:
|
||||
async_query = async_query.offset(query.offset)
|
||||
if query.columns:
|
||||
async_query = async_query.select(query.columns)
|
||||
if query.filter:
|
||||
async_query = async_query.where(query.filter)
|
||||
if query.fast_search:
|
||||
async_query = async_query.fast_search()
|
||||
if query.with_row_id:
|
||||
async_query = async_query.with_row_id()
|
||||
|
||||
if query.vector:
|
||||
async_query = (
|
||||
async_query.nearest_to(query.vector)
|
||||
.distance_type(query.metric)
|
||||
.nprobes(query.nprobes)
|
||||
)
|
||||
if query.refine_factor:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
if query.vector_column:
|
||||
async_query = async_query.column(query.vector_column)
|
||||
|
||||
if not query.prefilter:
|
||||
async_query = async_query.postfilter()
|
||||
|
||||
if isinstance(query.full_text_query, str):
|
||||
async_query = async_query.nearest_to_text(query.full_text_query)
|
||||
elif isinstance(query.full_text_query, dict):
|
||||
fts_query = query.full_text_query["query"]
|
||||
fts_columns = query.full_text_query.get("columns", []) or []
|
||||
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
|
||||
|
||||
table = await async_query.to_arrow()
|
||||
return table.to_reader()
|
||||
|
||||
async def _do_merge(
|
||||
self,
|
||||
@@ -2781,7 +2951,7 @@ class AsyncTable:
|
||||
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
|
||||
return await self._inner.optimize(cleanup_older_than, delete_unverified)
|
||||
|
||||
async def list_indices(self) -> IndexConfig:
|
||||
async def list_indices(self) -> Iterable[IndexConfig]:
|
||||
"""
|
||||
List all indices that have been created with Self::create_index
|
||||
"""
|
||||
@@ -2865,3 +3035,8 @@ class IndexStatistics:
|
||||
]
|
||||
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
|
||||
num_indices: Optional[int] = None
|
||||
|
||||
# This exists for backwards compatibility with an older API, which returned
|
||||
# a dictionary instead of a class.
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
@@ -196,6 +196,7 @@ def test_add_optional_vector(tmp_path):
|
||||
"ollama",
|
||||
"cohere",
|
||||
"instructor",
|
||||
"voyageai",
|
||||
],
|
||||
)
|
||||
def test_embedding_function_safe_model_dump(embedding_type):
|
||||
|
||||
@@ -18,7 +18,6 @@ import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
@@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path):
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path):
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -481,3 +481,22 @@ def test_ollama_embedding(tmp_path):
|
||||
json.dumps(dumped_model)
|
||||
except TypeError:
|
||||
pytest.fail("Failed to JSON serialize the dumped model")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_embedding_function():
|
||||
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
@@ -235,6 +235,29 @@ async def test_search_fts_async(async_table):
|
||||
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
|
||||
assert len(results) == 5
|
||||
|
||||
expected_count = await async_table.count_rows(
|
||||
"count > 5000 and contains(text, 'puppy')"
|
||||
)
|
||||
expected_count = min(expected_count, 10)
|
||||
|
||||
limited_results_pre_filter = await (
|
||||
async_table.query()
|
||||
.nearest_to_text("puppy")
|
||||
.where("count > 5000")
|
||||
.limit(10)
|
||||
.to_list()
|
||||
)
|
||||
assert len(limited_results_pre_filter) == expected_count
|
||||
limited_results_post_filter = await (
|
||||
async_table.query()
|
||||
.nearest_to_text("puppy")
|
||||
.where("count > 5000")
|
||||
.limit(10)
|
||||
.postfilter()
|
||||
.to_list()
|
||||
)
|
||||
assert len(limited_results_post_filter) <= expected_count
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fts_specify_column_async(async_table):
|
||||
|
||||
@@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
||||
# Can recreate if replace=True
|
||||
await some_table.create_index("id", replace=True)
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices) == '[Index(BTree, columns=["id"])]'
|
||||
assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "BTree"
|
||||
assert indices[0].columns == ["id"]
|
||||
@@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
||||
async def test_create_bitmap_index(some_table: AsyncTable):
|
||||
await some_table.create_index("id", config=Bitmap())
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices) == '[Index(Bitmap, columns=["id"])]'
|
||||
assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]'
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
index_name = indices[0].name
|
||||
@@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
|
||||
async def test_create_label_list_index(some_table: AsyncTable):
|
||||
await some_table.create_index("tags", config=LabelList())
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices) == '[Index(LabelList, columns=["tags"])]'
|
||||
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -17,6 +17,7 @@ from typing import Optional
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
@@ -330,6 +331,12 @@ async def test_query_async(table_async: AsyncTable):
|
||||
# Also check an empty query
|
||||
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
|
||||
|
||||
# with row id
|
||||
await check_query(
|
||||
table_async.query().select(["id", "vector"]).with_row_id(),
|
||||
expected_columns=["id", "vector", "_rowid"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_arrow_async(table_async: AsyncTable):
|
||||
@@ -358,6 +365,25 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
|
||||
assert df.shape == (0, 4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_search_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
|
||||
np.random.rand(256, 32)
|
||||
).storage
|
||||
table = await db.create_table("test", pa.table({"vector": vectors}))
|
||||
await table.create_index(
|
||||
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=1)
|
||||
)
|
||||
await table.add(pa.table({"vector": vectors}))
|
||||
|
||||
q = [1.0] * 32
|
||||
plan = await table.query().nearest_to(q).explain_plan(True)
|
||||
assert "LanceScan" in plan
|
||||
plan = await table.query().nearest_to(q).fast_search().explain_plan(True)
|
||||
assert "LanceScan" not in plan
|
||||
|
||||
|
||||
def test_explain_plan(table):
|
||||
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
plan = q.explain_plan(verbose=True)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import attrs
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||
|
||||
|
||||
@attrs.define
|
||||
class MockLanceDBServer:
|
||||
runner: web.AppRunner = attrs.field(init=False)
|
||||
site: web.TCPSite = attrs.field(init=False)
|
||||
|
||||
async def query_handler(self, request: web.Request) -> web.Response:
|
||||
table_name = request.match_info["table_name"]
|
||||
assert table_name == "test_table"
|
||||
|
||||
await request.json()
|
||||
# TODO: do some matching
|
||||
|
||||
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
|
||||
ids = pd.Series(range(10), name="id")
|
||||
df = pd.DataFrame([vecs, ids]).T
|
||||
|
||||
batch = pa.RecordBatch.from_pandas(
|
||||
df,
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 128)),
|
||||
pa.field("id", pa.int64()),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
sink = pa.BufferOutputStream()
|
||||
with pa.ipc.new_file(sink, batch.schema) as writer:
|
||||
writer.write_batch(batch)
|
||||
|
||||
return web.Response(body=sink.getvalue().to_pybytes())
|
||||
|
||||
async def setup(self):
|
||||
app = web.Application()
|
||||
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
|
||||
self.runner = web.AppRunner(app)
|
||||
await self.runner.setup()
|
||||
self.site = web.TCPSite(self.runner, "localhost", 8111)
|
||||
|
||||
async def start(self):
|
||||
await self.site.start()
|
||||
|
||||
async def stop(self):
|
||||
await self.runner.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky somehow, fix later")
|
||||
@pytest.mark.asyncio
|
||||
async def test_e2e_with_mock_server():
|
||||
mock_server = MockLanceDBServer()
|
||||
await mock_server.setup()
|
||||
await mock_server.start()
|
||||
|
||||
try:
|
||||
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
|
||||
df = (
|
||||
await client.query(
|
||||
"test_table",
|
||||
VectorQuery(
|
||||
vector=np.random.rand(128).tolist(),
|
||||
k=10,
|
||||
_metric="L2",
|
||||
columns=["id", "vector"],
|
||||
),
|
||||
)
|
||||
).to_pandas()
|
||||
|
||||
assert "vector" in df.columns
|
||||
assert "id" in df.columns
|
||||
|
||||
assert client.closed
|
||||
finally:
|
||||
# make sure we don't leak resources
|
||||
await mock_server.stop()
|
||||
@@ -2,91 +2,19 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import contextlib
|
||||
from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
import uuid
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.remote import ClientConfig
|
||||
from lancedb.remote.errors import HttpError, RetryError
|
||||
import pyarrow as pa
|
||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
import pytest
|
||||
|
||||
|
||||
class FakeLanceDBClient:
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
assert table_name == "test"
|
||||
t = pa.schema([]).empty_table()
|
||||
return VectorQueryResult(t)
|
||||
|
||||
def post(self, path: str):
|
||||
pass
|
||||
|
||||
def mount_retry_adapter_for_table(self, table_name: str):
|
||||
pass
|
||||
|
||||
|
||||
def test_remote_db():
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
setattr(conn, "_client", FakeLanceDBClient())
|
||||
|
||||
table = conn["test"]
|
||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
table.search([1.0, 2.0]).to_pandas()
|
||||
|
||||
|
||||
def test_create_empty_table():
|
||||
client = MagicMock()
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
|
||||
conn._client = client
|
||||
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
|
||||
client.post.return_value = {"status": "ok"}
|
||||
table = conn.create_table("test", schema=schema)
|
||||
assert table.name == "test"
|
||||
assert client.post.call_args[0][0] == "/v1/table/test/create/"
|
||||
|
||||
json_schema = {
|
||||
"fields": [
|
||||
{
|
||||
"name": "vector",
|
||||
"nullable": True,
|
||||
"type": {
|
||||
"type": "fixed_size_list",
|
||||
"fields": [
|
||||
{"name": "item", "nullable": True, "type": {"type": "float"}}
|
||||
],
|
||||
"length": 2,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
client.post.return_value = {"schema": json_schema}
|
||||
assert table.schema == schema
|
||||
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
|
||||
|
||||
client.post.return_value = 0
|
||||
assert table.count_rows(None) == 0
|
||||
|
||||
|
||||
def test_create_table_with_recordbatches():
|
||||
client = MagicMock()
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
|
||||
conn._client = client
|
||||
|
||||
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
|
||||
|
||||
client.post.return_value = {"status": "ok"}
|
||||
table = conn.create_table("test", [batch], schema=batch.schema)
|
||||
assert table.name == "test"
|
||||
assert client.post.call_args[0][0] == "/v1/table/test/create/"
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def make_mock_http_handler(handler):
|
||||
@@ -100,8 +28,35 @@ def make_mock_http_handler(handler):
|
||||
return MockLanceDBHandler
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_lancedb_connection(handler):
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 8080), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override="http://localhost:8080",
|
||||
client_config={
|
||||
"retry_config": {"retries": 2},
|
||||
"timeout_config": {
|
||||
"connect_timeout": 1,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def mock_lancedb_connection(handler):
|
||||
async def mock_lancedb_connection_async(handler):
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 8080), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
@@ -143,7 +98,7 @@ async def test_async_remote_db():
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
async with mock_lancedb_connection(handler) as db:
|
||||
async with mock_lancedb_connection_async(handler) as db:
|
||||
table_names = await db.table_names()
|
||||
assert table_names == []
|
||||
|
||||
@@ -159,12 +114,12 @@ async def test_http_error():
|
||||
request.end_headers()
|
||||
request.wfile.write(b"Internal Server Error")
|
||||
|
||||
async with mock_lancedb_connection(handler) as db:
|
||||
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
|
||||
async with mock_lancedb_connection_async(handler) as db:
|
||||
with pytest.raises(HttpError) as exc_info:
|
||||
await db.table_names()
|
||||
|
||||
assert exc_info.value.request_id == request_id_holder["request_id"]
|
||||
assert exc_info.value.status_code == 507
|
||||
assert "Internal Server Error" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -178,15 +133,225 @@ async def test_retry_error():
|
||||
request.end_headers()
|
||||
request.wfile.write(b"Try again later")
|
||||
|
||||
async with mock_lancedb_connection(handler) as db:
|
||||
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
|
||||
async with mock_lancedb_connection_async(handler) as db:
|
||||
with pytest.raises(RetryError) as exc_info:
|
||||
await db.table_names()
|
||||
|
||||
assert exc_info.value.request_id == request_id_holder["request_id"]
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
cause = exc_info.value.__cause__
|
||||
assert isinstance(cause, HttpError)
|
||||
assert "Try again later" in str(cause)
|
||||
assert cause.request_id == request_id_holder["request_id"]
|
||||
assert cause.status_code == 429
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def query_test_table(query_handler):
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
elif request.path == "/v1/table/test/query/":
|
||||
content_len = int(request.headers.get("Content-Length"))
|
||||
body = request.rfile.read(content_len)
|
||||
body = json.loads(body)
|
||||
|
||||
data = query_handler(body)
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
|
||||
request.end_headers()
|
||||
|
||||
with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
|
||||
f.write_table(data)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
assert repr(db) == "RemoteConnect(name=dev)"
|
||||
table = db.open_table("test")
|
||||
assert repr(table) == "RemoteTable(dev.test)"
|
||||
yield table
|
||||
|
||||
|
||||
def test_query_sync_minimal():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
data = table.search([1, 2, 3]).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
|
||||
def test_query_sync_maximal():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "cosine",
|
||||
"k": 42,
|
||||
"prefilter": True,
|
||||
"refine_factor": 10,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"filter": "id > 0",
|
||||
"columns": ["id", "name"],
|
||||
"vector_column": "vector2",
|
||||
"fast_search": True,
|
||||
"with_row_id": True,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(
|
||||
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
|
||||
.metric("cosine")
|
||||
.limit(42)
|
||||
.refine_factor(10)
|
||||
.nprobes(5)
|
||||
.where("id > 0", prefilter=True)
|
||||
.with_row_id(True)
|
||||
.select(["id", "name"])
|
||||
.to_list()
|
||||
)
|
||||
|
||||
|
||||
def test_query_sync_fts():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"full_text_query": {
|
||||
"query": "puppy",
|
||||
"columns": [],
|
||||
},
|
||||
"k": 10,
|
||||
"vector": [],
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(table.search("puppy", query_type="fts").to_list())
|
||||
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"full_text_query": {
|
||||
"query": "puppy",
|
||||
"columns": ["name", "description"],
|
||||
},
|
||||
"k": 42,
|
||||
"vector": [],
|
||||
"with_row_id": True,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(
|
||||
table.search("puppy", query_type="fts", fts_columns=["name", "description"])
|
||||
.with_row_id(True)
|
||||
.limit(42)
|
||||
.to_list()
|
||||
)
|
||||
|
||||
|
||||
def test_query_sync_hybrid():
|
||||
def handler(body):
|
||||
if "full_text_query" in body:
|
||||
# FTS query
|
||||
assert body == {
|
||||
"full_text_query": {
|
||||
"query": "puppy",
|
||||
"columns": [],
|
||||
},
|
||||
"k": 42,
|
||||
"vector": [],
|
||||
"with_row_id": True,
|
||||
}
|
||||
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
|
||||
else:
|
||||
# Vector query
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 42,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
"with_row_id": True,
|
||||
}
|
||||
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
embedding_func = MockTextEmbeddingFunction()
|
||||
embedding_config = MagicMock()
|
||||
embedding_config.function = embedding_func
|
||||
|
||||
embedding_funcs = MagicMock()
|
||||
embedding_funcs.get = MagicMock(return_value=embedding_config)
|
||||
table.embedding_functions = embedding_funcs
|
||||
|
||||
(table.search("puppy", query_type="hybrid").limit(42).to_list())
|
||||
|
||||
|
||||
def test_create_client():
|
||||
mandatory_args = {
|
||||
"uri": "db://dev",
|
||||
"api_key": "fake-api-key",
|
||||
"region": "us-east-1",
|
||||
}
|
||||
|
||||
db = lancedb.connect(**mandatory_args)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
|
||||
db = lancedb.connect(**mandatory_args, client_config={})
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args,
|
||||
client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args,
|
||||
client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.retry_config.retries == 42
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args, client_config={"retry_config": {"retries": 42}}
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.retry_config.retries == 42
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
db = lancedb.connect(**mandatory_args, connection_timeout=42)
|
||||
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
db = lancedb.connect(**mandatory_args, read_timeout=42)
|
||||
assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
lancedb.connect(**mandatory_args, request_thread_pool=10)
|
||||
|
||||
@@ -16,6 +16,7 @@ from lancedb.rerankers import (
|
||||
OpenaiReranker,
|
||||
JinaReranker,
|
||||
AnswerdotaiRerankers,
|
||||
VoyageAIReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -344,3 +345,14 @@ def test_jina_reranker(tmp_path, use_tantivy):
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
reranker = JinaReranker()
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_voyageai_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("voyageai")
|
||||
reranker = VoyageAIReranker(model_name="rerank-2")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
@@ -1223,6 +1223,54 @@ async def test_time_travel(db_async: AsyncConnection):
|
||||
await table.restore()
|
||||
|
||||
|
||||
def test_sync_optimize(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
|
||||
table.create_scalar_index("price", index_type="BTREE")
|
||||
stats = table.to_lance().stats.index_stats("price_idx")
|
||||
assert stats["num_indexed_rows"] == 2
|
||||
|
||||
table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
|
||||
assert table.count_rows() == 3
|
||||
table.optimize()
|
||||
stats = table.to_lance().stats.index_stats("price_idx")
|
||||
assert stats["num_indexed_rows"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_optimize_in_async(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
|
||||
table.create_scalar_index("price", index_type="BTREE")
|
||||
stats = table.to_lance().stats.index_stats("price_idx")
|
||||
assert stats["num_indexed_rows"] == 2
|
||||
|
||||
table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
|
||||
assert table.count_rows() == 3
|
||||
try:
|
||||
table.optimize()
|
||||
except Exception as e:
|
||||
assert (
|
||||
"Synchronous method called in asynchronous context. "
|
||||
"If you are writing an asynchronous application "
|
||||
"then please use the asynchronous APIs" in str(e)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optimize(db_async: AsyncConnection):
|
||||
table = await db_async.create_table(
|
||||
|
||||
@@ -170,6 +170,17 @@ impl Connection {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn rename_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
old_name: String,
|
||||
new_name: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.rename_table(old_name, new_name).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
|
||||
@@ -24,8 +24,8 @@ use lancedb::{
|
||||
DistanceType,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods, PyResult,
|
||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
|
||||
};
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
@@ -236,7 +236,21 @@ pub struct IndexConfig {
|
||||
#[pymethods]
|
||||
impl IndexConfig {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("Index({}, columns={:?})", self.index_type, self.columns)
|
||||
format!(
|
||||
"Index({}, columns={:?}, name=\"{}\")",
|
||||
self.index_type, self.columns, self.name
|
||||
)
|
||||
}
|
||||
|
||||
// For backwards-compatibility with the old sync SDK, we also support getting
|
||||
// attributes via __getitem__.
|
||||
pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult<PyObject> {
|
||||
match key.as_str() {
|
||||
"index_type" => Ok(self.index_type.clone().into_py(py)),
|
||||
"columns" => Ok(self.columns.clone().into_py(py)),
|
||||
"name" | "index_name" => Ok(self.name.clone().into_py(py)),
|
||||
_ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -68,6 +68,18 @@ impl Query {
|
||||
self.inner = self.inner.clone().offset(offset as usize);
|
||||
}
|
||||
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
pub fn postfilter(&mut self) {
|
||||
self.inner = self.inner.clone().postfilter();
|
||||
}
|
||||
|
||||
pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> {
|
||||
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
|
||||
let array = make_array(data);
|
||||
@@ -146,6 +158,14 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().offset(offset as usize);
|
||||
}
|
||||
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
pub fn column(&mut self, column: String) {
|
||||
self.inner = self.inner.clone().column(&column);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-node"
|
||||
version = "0.11.1-beta.1"
|
||||
version = "0.13.0-beta.1"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.11.1-beta.1"
|
||||
version = "0.13.0-beta.1"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -39,9 +39,6 @@ use crate::utils::validate_table_name;
|
||||
use crate::Table;
|
||||
pub use lance_encoding::version::LanceFileVersion;
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use log::warn;
|
||||
|
||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
|
||||
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
|
||||
@@ -719,8 +716,7 @@ impl ConnectBuilder {
|
||||
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
// TODO: remove this warning when the remote client is ready
|
||||
warn!("The rust implementation of the remote client is not yet ready for use.");
|
||||
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
&self.uri,
|
||||
&api_key,
|
||||
|
||||
@@ -29,6 +29,7 @@ pub mod scalar;
|
||||
pub mod vector;
|
||||
|
||||
/// Supported index types.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Index {
|
||||
Auto,
|
||||
/// A `BTree` index is an sorted index on scalar columns.
|
||||
@@ -119,6 +120,7 @@ pub enum IndexType {
|
||||
#[serde(alias = "LABEL_LIST")]
|
||||
LabelList,
|
||||
// FTS
|
||||
#[serde(alias = "INVERTED", alias = "Inverted")]
|
||||
FTS,
|
||||
}
|
||||
|
||||
|
||||
@@ -403,6 +403,26 @@ pub trait QueryBase {
|
||||
/// By default, it is false.
|
||||
fn fast_search(self) -> Self;
|
||||
|
||||
/// If this is called then filtering will happen after the vector search instead of
|
||||
/// before.
|
||||
///
|
||||
/// By default filtering will be performed before the vector search. This is how
|
||||
/// filtering is typically understood to work. This prefilter step does add some
|
||||
/// additional latency. Creating a scalar index on the filter column(s) can
|
||||
/// often improve this latency. However, sometimes a filter is too complex or scalar
|
||||
/// indices cannot be applied to the column. In these cases postfiltering can be
|
||||
/// used instead of prefiltering to improve latency.
|
||||
///
|
||||
/// Post filtering applies the filter to the results of the vector search. This means
|
||||
/// we only run the filter on a much smaller set of data. However, it can cause the
|
||||
/// query to return fewer than `limit` results (or even no results) if none of the nearest
|
||||
/// results match the filter.
|
||||
///
|
||||
/// Post filtering happens during the "refine stage" (described in more detail in
|
||||
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
|
||||
/// help restore some of the results lost by post filtering.
|
||||
fn postfilter(self) -> Self;
|
||||
|
||||
/// Return the `_rowid` meta column from the Table.
|
||||
fn with_row_id(self) -> Self;
|
||||
}
|
||||
@@ -442,6 +462,11 @@ impl<T: HasQuery> QueryBase for T {
|
||||
self
|
||||
}
|
||||
|
||||
fn postfilter(mut self) -> Self {
|
||||
self.mut_query().prefilter = false;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_row_id(mut self) -> Self {
|
||||
self.mut_query().with_row_id = true;
|
||||
self
|
||||
@@ -561,6 +586,9 @@ pub struct Query {
|
||||
///
|
||||
/// By default, this is false.
|
||||
pub(crate) with_row_id: bool,
|
||||
|
||||
/// If set to false, the filter will be applied after the vector search.
|
||||
pub(crate) prefilter: bool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -574,6 +602,7 @@ impl Query {
|
||||
select: Select::All,
|
||||
fast_search: false,
|
||||
with_row_id: false,
|
||||
prefilter: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -678,8 +707,6 @@ pub struct VectorQuery {
|
||||
pub(crate) distance_type: Option<DistanceType>,
|
||||
/// Default is true. Set to false to enforce a brute force search.
|
||||
pub(crate) use_index: bool,
|
||||
/// Apply filter before ANN search/
|
||||
pub(crate) prefilter: bool,
|
||||
}
|
||||
|
||||
impl VectorQuery {
|
||||
@@ -692,7 +719,6 @@ impl VectorQuery {
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
use_index: true,
|
||||
prefilter: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -782,29 +808,6 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
/// If this is called then filtering will happen after the vector search instead of
|
||||
/// before.
|
||||
///
|
||||
/// By default filtering will be performed before the vector search. This is how
|
||||
/// filtering is typically understood to work. This prefilter step does add some
|
||||
/// additional latency. Creating a scalar index on the filter column(s) can
|
||||
/// often improve this latency. However, sometimes a filter is too complex or scalar
|
||||
/// indices cannot be applied to the column. In these cases postfiltering can be
|
||||
/// used instead of prefiltering to improve latency.
|
||||
///
|
||||
/// Post filtering applies the filter to the results of the vector search. This means
|
||||
/// we only run the filter on a much smaller set of data. However, it can cause the
|
||||
/// query to return fewer than `limit` results (or even no results) if none of the nearest
|
||||
/// results match the filter.
|
||||
///
|
||||
/// Post filtering happens during the "refine stage" (described in more detail in
|
||||
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
|
||||
/// help restore some of the results lost by post filtering.
|
||||
pub fn postfilter(mut self) -> Self {
|
||||
self.prefilter = false;
|
||||
self
|
||||
}
|
||||
|
||||
/// If this is called then any vector index is skipped
|
||||
///
|
||||
/// An exhaustive (flat) search will be performed. The query vector will
|
||||
|
||||
@@ -23,6 +23,8 @@ pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
|
||||
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||
#[cfg(test)]
|
||||
const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
|
||||
const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
|
||||
|
||||
@@ -341,7 +341,22 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
request_id
|
||||
};
|
||||
|
||||
debug!("Sending request_id={}: {:?}", request_id, &request);
|
||||
if log::log_enabled!(log::Level::Debug) {
|
||||
let content_type = request
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.map(|v| v.to_str().unwrap());
|
||||
if content_type == Some("application/json") {
|
||||
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
|
||||
let body = String::from_utf8_lossy(body);
|
||||
debug!(
|
||||
"Sending request_id={}: {:?} with body {}",
|
||||
request_id, request, body
|
||||
);
|
||||
} else {
|
||||
debug!("Sending request_id={}: {:?}", request_id, request);
|
||||
}
|
||||
}
|
||||
|
||||
if with_retry {
|
||||
self.send_with_retry_impl(client, request, request_id).await
|
||||
|
||||
@@ -161,7 +161,7 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
if self.table_cache.get(&options.name).is_none() {
|
||||
let req = self
|
||||
.client
|
||||
.get(&format!("/v1/table/{}/describe/", options.name));
|
||||
.post(&format!("/v1/table/{}/describe/", options.name));
|
||||
let (request_id, resp) = self.client.send(req, true).await?;
|
||||
if resp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound { name: options.name });
|
||||
@@ -301,7 +301,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_open_table() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/table1/describe/");
|
||||
assert_eq!(request.url().query(), None);
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::io::Cursor;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::index::Index;
|
||||
@@ -7,10 +8,9 @@ use crate::table::AddDataMode;
|
||||
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
|
||||
use crate::Error;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_ipc::reader::StreamReader;
|
||||
use arrow_ipc::reader::FileReader;
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use bytes::Buf;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
|
||||
@@ -115,39 +115,14 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
async fn read_arrow_stream(
|
||||
&self,
|
||||
request_id: &str,
|
||||
body: reqwest::Response,
|
||||
response: reqwest::Response,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
// Assert that the content type is correct
|
||||
let content_type = body
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.ok_or_else(|| Error::Http {
|
||||
source: "Missing content type".into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
})?
|
||||
.to_str()
|
||||
.map_err(|e| Error::Http {
|
||||
source: format!("Failed to parse content type: {}", e).into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
})?;
|
||||
if content_type != ARROW_STREAM_CONTENT_TYPE {
|
||||
return Err(Error::Http {
|
||||
source: format!(
|
||||
"Expected content type {}, got {}",
|
||||
ARROW_STREAM_CONTENT_TYPE, content_type
|
||||
)
|
||||
.into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
});
|
||||
}
|
||||
let response = self.check_table_response(request_id, response).await?;
|
||||
|
||||
// There isn't a way to actually stream this data yet. I have an upstream issue:
|
||||
// https://github.com/apache/arrow-rs/issues/6420
|
||||
let body = body.bytes().await.err_to_http(request_id.into())?;
|
||||
let reader = StreamReader::try_new(body.reader(), None)?;
|
||||
let body = response.bytes().await.err_to_http(request_id.into())?;
|
||||
let reader = FileReader::try_new(Cursor::new(body), None)?;
|
||||
let schema = reader.schema();
|
||||
let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
|
||||
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
|
||||
@@ -192,6 +167,10 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
body["fast_search"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
if params.with_row_id {
|
||||
body["with_row_id"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
if let Some(full_text_search) = ¶ms.full_text_search {
|
||||
if full_text_search.wand_factor.is_some() {
|
||||
return Err(Error::NotSupported {
|
||||
@@ -277,7 +256,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
.post(&format!("/v1/table/{}/count_rows/", self.name));
|
||||
|
||||
if let Some(filter) = filter {
|
||||
request = request.json(&serde_json::json!({ "filter": filter }));
|
||||
request = request.json(&serde_json::json!({ "predicate": filter }));
|
||||
} else {
|
||||
request = request.json(&serde_json::json!({}));
|
||||
}
|
||||
@@ -330,13 +309,13 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
let mut body = serde_json::Value::Object(Default::default());
|
||||
Self::apply_query_params(&mut body, &query.base)?;
|
||||
|
||||
body["prefilter"] = query.prefilter.into();
|
||||
body["prefilter"] = query.base.prefilter.into();
|
||||
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
|
||||
body["nprobes"] = query.nprobes.into();
|
||||
body["refine_factor"] = query.refine_factor.into();
|
||||
|
||||
if let Some(vector) = query.query_vector.as_ref() {
|
||||
let vector: Vec<f32> = match vector.data_type() {
|
||||
let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
|
||||
match vector.data_type() {
|
||||
DataType::Float32 => vector
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::Float32Array>()
|
||||
@@ -350,9 +329,12 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
message: "VectorQuery vector must be of type Float32".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
body["vector"] = serde_json::json!(vector);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Server takes empty vector, not null or undefined.
|
||||
Vec::new()
|
||||
};
|
||||
body["vector"] = serde_json::json!(vector);
|
||||
|
||||
if let Some(vector_column) = query.column.as_ref() {
|
||||
body["vector_column"] = serde_json::Value::String(vector_column.clone());
|
||||
@@ -383,6 +365,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
|
||||
let mut body = serde_json::Value::Object(Default::default());
|
||||
Self::apply_query_params(&mut body, query)?;
|
||||
// Empty vector can be passed if no vector search is performed.
|
||||
body["vector"] = serde_json::Value::Array(Vec::new());
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
@@ -399,30 +383,19 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
|
||||
let mut updates = Vec::new();
|
||||
for (column, expression) in update.columns {
|
||||
updates.push(column);
|
||||
updates.push(expression);
|
||||
updates.push(vec![column, expression]);
|
||||
}
|
||||
|
||||
let request = request.json(&serde_json::json!({
|
||||
"updates": updates,
|
||||
"only_if": update.filter,
|
||||
"predicate": update.filter,
|
||||
}));
|
||||
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
|
||||
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
source: format!(
|
||||
"Failed to parse updated rows result from response {}: {}",
|
||||
body, e
|
||||
)
|
||||
.into(),
|
||||
request_id,
|
||||
status_code: None,
|
||||
})
|
||||
Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
|
||||
}
|
||||
async fn delete(&self, predicate: &str) -> Result<()> {
|
||||
let body = serde_json::json!({ "predicate": predicate });
|
||||
@@ -691,6 +664,7 @@ mod tests {
|
||||
use crate::{
|
||||
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
remote::ARROW_FILE_CONTENT_TYPE,
|
||||
DistanceType, Error, Table,
|
||||
};
|
||||
|
||||
@@ -804,7 +778,7 @@ mod tests {
|
||||
);
|
||||
assert_eq!(
|
||||
request.body().unwrap().as_bytes().unwrap(),
|
||||
br#"{"filter":"a > 10"}"#
|
||||
br#"{"predicate":"a > 10"}"#
|
||||
);
|
||||
|
||||
http::Response::builder().status(200).body("42").unwrap()
|
||||
@@ -839,6 +813,17 @@ mod tests {
|
||||
body
|
||||
}
|
||||
|
||||
fn write_ipc_file(data: &RecordBatch) -> Vec<u8> {
|
||||
let mut body = Vec::new();
|
||||
{
|
||||
let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema())
|
||||
.expect("Failed to create writer");
|
||||
writer.write(data).expect("Failed to write data");
|
||||
writer.finish().expect("Failed to finish");
|
||||
}
|
||||
body
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_append() {
|
||||
let data = RecordBatch::try_new(
|
||||
@@ -947,21 +932,27 @@ mod tests {
|
||||
let updates = value.get("updates").unwrap().as_array().unwrap();
|
||||
assert!(updates.len() == 2);
|
||||
|
||||
let col_name = updates[0].as_str().unwrap();
|
||||
let expression = updates[1].as_str().unwrap();
|
||||
let col_name = updates[0][0].as_str().unwrap();
|
||||
let expression = updates[0][1].as_str().unwrap();
|
||||
assert_eq!(col_name, "a");
|
||||
assert_eq!(expression, "a + 1");
|
||||
|
||||
let only_if = value.get("only_if").unwrap().as_str().unwrap();
|
||||
let col_name = updates[1][0].as_str().unwrap();
|
||||
let expression = updates[1][1].as_str().unwrap();
|
||||
assert_eq!(col_name, "b");
|
||||
assert_eq!(expression, "b - 1");
|
||||
|
||||
let only_if = value.get("predicate").unwrap().as_str().unwrap();
|
||||
assert_eq!(only_if, "b > 10");
|
||||
}
|
||||
|
||||
http::Response::builder().status(200).body("1").unwrap()
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
|
||||
table
|
||||
.update()
|
||||
.column("a", "a + 1")
|
||||
.column("b", "b - 1")
|
||||
.only_if("b > 10")
|
||||
.execute()
|
||||
.await
|
||||
@@ -1092,10 +1083,10 @@ mod tests {
|
||||
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
let response_body = write_ipc_stream(&expected_data_ref);
|
||||
let response_body = write_ipc_file(&expected_data_ref);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
@@ -1142,10 +1133,10 @@ mod tests {
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let response_body = write_ipc_stream(&data);
|
||||
let response_body = write_ipc_file(&data);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
@@ -1185,6 +1176,8 @@ mod tests {
|
||||
"query": "hello world",
|
||||
},
|
||||
"k": 10,
|
||||
"vector": [],
|
||||
"with_row_id": true,
|
||||
});
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
@@ -1193,10 +1186,10 @@ mod tests {
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let response_body = write_ipc_stream(&data);
|
||||
let response_body = write_ipc_file(&data);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
@@ -1207,6 +1200,7 @@ mod tests {
|
||||
FullTextSearchQuery::new("hello world".into())
|
||||
.columns(Some(vec!["a".into(), "b".into()])),
|
||||
)
|
||||
.with_row_id()
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
|
||||
@@ -1842,7 +1842,7 @@ impl TableInternal for NativeTable {
|
||||
|
||||
scanner.nprobs(query.nprobes);
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.prefilter);
|
||||
scanner.prefilter(query.base.prefilter);
|
||||
match query.base.select {
|
||||
Select::Columns(ref columns) => {
|
||||
scanner.project(columns.as_slice())?;
|
||||
@@ -3123,6 +3123,12 @@ mod tests {
|
||||
assert_eq!(index.index_type, crate::index::IndexType::FTS);
|
||||
assert_eq!(index.columns, vec!["text".to_string()]);
|
||||
assert_eq!(index.name, "text_idx");
|
||||
|
||||
let stats = table.index_stats("text_idx").await.unwrap().unwrap();
|
||||
assert_eq!(stats.num_indexed_rows, num_rows);
|
||||
assert_eq!(stats.num_unindexed_rows, 0);
|
||||
assert_eq!(stats.index_type, crate::index::IndexType::FTS);
|
||||
assert_eq!(stats.distance_type, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user