mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
37 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ca0b15354 | ||
|
|
d8c217b47d | ||
|
|
b724b1a01f | ||
|
|
abd75e0ead | ||
|
|
0fd8a50bd7 | ||
|
|
9f228feb0e | ||
|
|
90e9c52d0a | ||
|
|
68974a4e06 | ||
|
|
4c9bab0d92 | ||
|
|
5117aecc38 | ||
|
|
729718cb09 | ||
|
|
b1c84e0bda | ||
|
|
cbbc07d0f5 | ||
|
|
21021f94ca | ||
|
|
0ed77fa990 | ||
|
|
4372c231cd | ||
|
|
fa9ca8f7a6 | ||
|
|
2a35d24ee6 | ||
|
|
dd9ce337e2 | ||
|
|
b9921d56cc | ||
|
|
0cfd9ed18e | ||
|
|
975398c3a8 | ||
|
|
08d5f93f34 | ||
|
|
91cab3b556 | ||
|
|
c61bfc3af8 | ||
|
|
4e8c7b0adf | ||
|
|
26f4a80e10 | ||
|
|
3604d20ad3 | ||
|
|
9708d829a9 | ||
|
|
059c9794b5 | ||
|
|
15ed7f75a0 | ||
|
|
96181ab421 | ||
|
|
f3fc339ef6 | ||
|
|
113cd6995b | ||
|
|
02535bdc88 | ||
|
|
facc7d61c0 | ||
|
|
f947259f16 |
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.11.1-beta.1"
|
current_version = "0.13.0-beta.1"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
@@ -92,6 +92,11 @@ glob = "node/package.json"
|
|||||||
replace = "\"@lancedb/vectordb-win32-x64-msvc\": \"{new_version}\""
|
replace = "\"@lancedb/vectordb-win32-x64-msvc\": \"{new_version}\""
|
||||||
search = "\"@lancedb/vectordb-win32-x64-msvc\": \"{current_version}\""
|
search = "\"@lancedb/vectordb-win32-x64-msvc\": \"{current_version}\""
|
||||||
|
|
||||||
|
[[tool.bumpversion.files]]
|
||||||
|
glob = "node/package.json"
|
||||||
|
replace = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{new_version}\""
|
||||||
|
search = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{current_version}\""
|
||||||
|
|
||||||
# Cargo files
|
# Cargo files
|
||||||
# ------------
|
# ------------
|
||||||
[[tool.bumpversion.files]]
|
[[tool.bumpversion.files]]
|
||||||
|
|||||||
@@ -38,3 +38,7 @@ rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm
|
|||||||
# not found errors on systems that are missing it.
|
# not found errors on systems that are missing it.
|
||||||
[target.x86_64-pc-windows-msvc]
|
[target.x86_64-pc-windows-msvc]
|
||||||
rustflags = ["-Ctarget-feature=+crt-static"]
|
rustflags = ["-Ctarget-feature=+crt-static"]
|
||||||
|
|
||||||
|
# Experimental target for Arm64 Windows
|
||||||
|
[target.aarch64-pc-windows-msvc]
|
||||||
|
rustflags = ["-Ctarget-feature=+crt-static"]
|
||||||
4
.github/workflows/docs.yml
vendored
4
.github/workflows/docs.yml
vendored
@@ -41,8 +41,8 @@ jobs:
|
|||||||
- name: Build Python
|
- name: Build Python
|
||||||
working-directory: python
|
working-directory: python
|
||||||
run: |
|
run: |
|
||||||
python -m pip install -e .
|
python -m pip install --extra-index-url https://pypi.fury.io/lancedb/ -e .
|
||||||
python -m pip install -r ../docs/requirements.txt
|
python -m pip install --extra-index-url https://pypi.fury.io/lancedb/ -r ../docs/requirements.txt
|
||||||
- name: Set up node
|
- name: Set up node
|
||||||
uses: actions/setup-node@v3
|
uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/docs_test.yml
vendored
2
.github/workflows/docs_test.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
|||||||
- name: Build Python
|
- name: Build Python
|
||||||
working-directory: docs/test
|
working-directory: docs/test
|
||||||
run:
|
run:
|
||||||
python -m pip install -r requirements.txt
|
python -m pip install --extra-index-url https://pypi.fury.io/lancedb/ -r requirements.txt
|
||||||
- name: Create test files
|
- name: Create test files
|
||||||
run: |
|
run: |
|
||||||
cd docs/test
|
cd docs/test
|
||||||
|
|||||||
16
.github/workflows/nodejs.yml
vendored
16
.github/workflows/nodejs.yml
vendored
@@ -53,6 +53,9 @@ jobs:
|
|||||||
cargo clippy --all --all-features -- -D warnings
|
cargo clippy --all --all-features -- -D warnings
|
||||||
npm ci
|
npm ci
|
||||||
npm run lint-ci
|
npm run lint-ci
|
||||||
|
- name: Lint examples
|
||||||
|
working-directory: nodejs/examples
|
||||||
|
run: npm ci && npm run lint-ci
|
||||||
linux:
|
linux:
|
||||||
name: Linux (NodeJS ${{ matrix.node-version }})
|
name: Linux (NodeJS ${{ matrix.node-version }})
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
@@ -91,6 +94,19 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
S3_TEST: "1"
|
S3_TEST: "1"
|
||||||
run: npm run test
|
run: npm run test
|
||||||
|
- name: Setup examples
|
||||||
|
working-directory: nodejs/examples
|
||||||
|
run: npm ci
|
||||||
|
- name: Test examples
|
||||||
|
working-directory: ./
|
||||||
|
env:
|
||||||
|
OPENAI_API_KEY: test
|
||||||
|
OPENAI_BASE_URL: http://0.0.0.0:8000
|
||||||
|
run: |
|
||||||
|
python ci/mock_openai.py &
|
||||||
|
ss -ltnp | grep :8000
|
||||||
|
cd nodejs/examples
|
||||||
|
npm test
|
||||||
macos:
|
macos:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
runs-on: "macos-14"
|
runs-on: "macos-14"
|
||||||
|
|||||||
200
.github/workflows/npm-publish.yml
vendored
200
.github/workflows/npm-publish.yml
vendored
@@ -226,6 +226,109 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
node/dist/lancedb-vectordb-win32*.tgz
|
node/dist/lancedb-vectordb-win32*.tgz
|
||||||
|
|
||||||
|
node-windows-arm64:
|
||||||
|
name: vectordb win32-arm64-msvc
|
||||||
|
runs-on: windows-4x-arm
|
||||||
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Install Git
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
|
||||||
|
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Git to PATH
|
||||||
|
run: |
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
|
||||||
|
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||||
|
shell: powershell
|
||||||
|
- name: Configure Git symlinks
|
||||||
|
run: git config --global core.symlinks true
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.13"
|
||||||
|
- name: Install Visual Studio Build Tools
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
|
||||||
|
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
|
||||||
|
"--installPath", "C:\BuildTools", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Visual Studio Build Tools to PATH
|
||||||
|
run: |
|
||||||
|
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
|
||||||
|
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
|
||||||
|
|
||||||
|
# Add MSVC runtime libraries to LIB
|
||||||
|
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
|
||||||
|
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
|
||||||
|
|
||||||
|
# Add INCLUDE paths
|
||||||
|
$env:INCLUDE = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\include;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\ucrt;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\um;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\shared"
|
||||||
|
Add-Content $env:GITHUB_ENV "INCLUDE=$env:INCLUDE"
|
||||||
|
shell: powershell
|
||||||
|
- name: Install Rust
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
|
||||||
|
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Rust to PATH
|
||||||
|
run: |
|
||||||
|
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
|
||||||
|
shell: powershell
|
||||||
|
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: rust
|
||||||
|
- name: Install 7-Zip ARM
|
||||||
|
run: |
|
||||||
|
New-Item -Path 'C:\7zip' -ItemType Directory
|
||||||
|
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
|
||||||
|
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add 7-Zip to PATH
|
||||||
|
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
|
||||||
|
shell: powershell
|
||||||
|
- name: Install Protoc v21.12
|
||||||
|
working-directory: C:\
|
||||||
|
run: |
|
||||||
|
if (Test-Path 'C:\protoc') {
|
||||||
|
Write-Host "Protoc directory exists, skipping installation"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
New-Item -Path 'C:\protoc' -ItemType Directory
|
||||||
|
Set-Location C:\protoc
|
||||||
|
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
|
||||||
|
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Protoc to PATH
|
||||||
|
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
|
||||||
|
shell: powershell
|
||||||
|
- name: Build Windows native node modules
|
||||||
|
run: .\ci\build_windows_artifacts.ps1 aarch64-pc-windows-msvc
|
||||||
|
- name: Upload Windows ARM64 Artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: node-native-windows-arm64
|
||||||
|
path: |
|
||||||
|
node/dist/*.node
|
||||||
|
|
||||||
nodejs-windows:
|
nodejs-windows:
|
||||||
name: lancedb ${{ matrix.target }}
|
name: lancedb ${{ matrix.target }}
|
||||||
runs-on: windows-2022
|
runs-on: windows-2022
|
||||||
@@ -260,9 +363,102 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
nodejs/dist/*.node
|
nodejs/dist/*.node
|
||||||
|
|
||||||
|
nodejs-windows-arm64:
|
||||||
|
name: lancedb win32-arm64-msvc
|
||||||
|
runs-on: windows-4x-arm
|
||||||
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Install Git
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
|
||||||
|
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Git to PATH
|
||||||
|
run: |
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
|
||||||
|
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||||
|
shell: powershell
|
||||||
|
- name: Configure Git symlinks
|
||||||
|
run: git config --global core.symlinks true
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.13"
|
||||||
|
- name: Install Visual Studio Build Tools
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
|
||||||
|
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
|
||||||
|
"--installPath", "C:\BuildTools", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Visual Studio Build Tools to PATH
|
||||||
|
run: |
|
||||||
|
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
|
||||||
|
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
|
||||||
|
|
||||||
|
$env:LIB = ""
|
||||||
|
Add-Content $env:GITHUB_ENV "LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
|
||||||
|
shell: powershell
|
||||||
|
- name: Install Rust
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
|
||||||
|
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Rust to PATH
|
||||||
|
run: |
|
||||||
|
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
|
||||||
|
shell: powershell
|
||||||
|
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: rust
|
||||||
|
- name: Install 7-Zip ARM
|
||||||
|
run: |
|
||||||
|
New-Item -Path 'C:\7zip' -ItemType Directory
|
||||||
|
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
|
||||||
|
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add 7-Zip to PATH
|
||||||
|
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
|
||||||
|
shell: powershell
|
||||||
|
- name: Install Protoc v21.12
|
||||||
|
working-directory: C:\
|
||||||
|
run: |
|
||||||
|
if (Test-Path 'C:\protoc') {
|
||||||
|
Write-Host "Protoc directory exists, skipping installation"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
New-Item -Path 'C:\protoc' -ItemType Directory
|
||||||
|
Set-Location C:\protoc
|
||||||
|
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
|
||||||
|
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Protoc to PATH
|
||||||
|
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
|
||||||
|
shell: powershell
|
||||||
|
- name: Build Windows native node modules
|
||||||
|
run: .\ci\build_windows_artifacts_nodejs.ps1 aarch64-pc-windows-msvc
|
||||||
|
- name: Upload Windows ARM64 Artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: nodejs-native-windows-arm64
|
||||||
|
path: |
|
||||||
|
nodejs/dist/*.node
|
||||||
|
|
||||||
release:
|
release:
|
||||||
name: vectordb NPM Publish
|
name: vectordb NPM Publish
|
||||||
needs: [node, node-macos, node-linux, node-windows]
|
needs: [node, node-macos, node-linux, node-windows, node-windows-arm64]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
# Only runs on tags that matches the make-release action
|
# Only runs on tags that matches the make-release action
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
@@ -302,7 +498,7 @@ jobs:
|
|||||||
|
|
||||||
release-nodejs:
|
release-nodejs:
|
||||||
name: lancedb NPM Publish
|
name: lancedb NPM Publish
|
||||||
needs: [nodejs-macos, nodejs-linux, nodejs-windows]
|
needs: [nodejs-macos, nodejs-linux, nodejs-windows, nodejs-windows-arm64]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
# Only runs on tags that matches the make-release action
|
# Only runs on tags that matches the make-release action
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
|
|||||||
2
.github/workflows/python.yml
vendored
2
.github/workflows/python.yml
vendored
@@ -138,7 +138,7 @@ jobs:
|
|||||||
run: rm -rf target/wheels
|
run: rm -rf target/wheels
|
||||||
windows:
|
windows:
|
||||||
name: "Windows: ${{ matrix.config.name }}"
|
name: "Windows: ${{ matrix.config.name }}"
|
||||||
timeout-minutes: 30
|
timeout-minutes: 60
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
config:
|
config:
|
||||||
|
|||||||
99
.github/workflows/rust.yml
vendored
99
.github/workflows/rust.yml
vendored
@@ -50,6 +50,7 @@ jobs:
|
|||||||
run: cargo fmt --all -- --check
|
run: cargo fmt --all -- --check
|
||||||
- name: Run clippy
|
- name: Run clippy
|
||||||
run: cargo clippy --workspace --tests --all-features -- -D warnings
|
run: cargo clippy --workspace --tests --all-features -- -D warnings
|
||||||
|
|
||||||
linux:
|
linux:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
# To build all features, we need more disk space than is available
|
# To build all features, we need more disk space than is available
|
||||||
@@ -91,6 +92,7 @@ jobs:
|
|||||||
run: cargo test --all-features
|
run: cargo test --all-features
|
||||||
- name: Run examples
|
- name: Run examples
|
||||||
run: cargo run --example simple
|
run: cargo run --example simple
|
||||||
|
|
||||||
macos:
|
macos:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
strategy:
|
strategy:
|
||||||
@@ -118,6 +120,7 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
# Run with everything except the integration tests.
|
# Run with everything except the integration tests.
|
||||||
run: cargo test --features remote,fp16kernels
|
run: cargo test --features remote,fp16kernels
|
||||||
|
|
||||||
windows:
|
windows:
|
||||||
runs-on: windows-2022
|
runs-on: windows-2022
|
||||||
steps:
|
steps:
|
||||||
@@ -139,3 +142,99 @@ jobs:
|
|||||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||||
cargo build
|
cargo build
|
||||||
cargo test
|
cargo test
|
||||||
|
|
||||||
|
windows-arm64:
|
||||||
|
runs-on: windows-4x-arm
|
||||||
|
steps:
|
||||||
|
- name: Install Git
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe"
|
||||||
|
Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Git to PATH
|
||||||
|
run: |
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin"
|
||||||
|
$env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User")
|
||||||
|
shell: powershell
|
||||||
|
- name: Configure Git symlinks
|
||||||
|
run: git config --global core.symlinks true
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.13"
|
||||||
|
- name: Install Visual Studio Build Tools
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe"
|
||||||
|
Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", `
|
||||||
|
"--installPath", "C:\BuildTools", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.ATL", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", `
|
||||||
|
"--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Visual Studio Build Tools to PATH
|
||||||
|
run: |
|
||||||
|
$vsPath = "C:\BuildTools\VC\Tools\MSVC"
|
||||||
|
$latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64"
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin"
|
||||||
|
|
||||||
|
# Add MSVC runtime libraries to LIB
|
||||||
|
$env:LIB = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\lib\arm64;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64"
|
||||||
|
Add-Content $env:GITHUB_ENV "LIB=$env:LIB"
|
||||||
|
|
||||||
|
# Add INCLUDE paths
|
||||||
|
$env:INCLUDE = "C:\BuildTools\VC\Tools\MSVC\$latestVersion\include;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\ucrt;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\um;" +
|
||||||
|
"C:\Program Files (x86)\Windows Kits\10\Include\10.0.22621.0\shared"
|
||||||
|
Add-Content $env:GITHUB_ENV "INCLUDE=$env:INCLUDE"
|
||||||
|
shell: powershell
|
||||||
|
- name: Install Rust
|
||||||
|
run: |
|
||||||
|
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
|
||||||
|
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Rust to PATH
|
||||||
|
run: |
|
||||||
|
Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
|
||||||
|
shell: powershell
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: rust
|
||||||
|
- name: Install 7-Zip ARM
|
||||||
|
run: |
|
||||||
|
New-Item -Path 'C:\7zip' -ItemType Directory
|
||||||
|
Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
|
||||||
|
Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
|
||||||
|
shell: powershell
|
||||||
|
- name: Add 7-Zip to PATH
|
||||||
|
run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
|
||||||
|
shell: powershell
|
||||||
|
- name: Install Protoc v21.12
|
||||||
|
working-directory: C:\
|
||||||
|
run: |
|
||||||
|
if (Test-Path 'C:\protoc') {
|
||||||
|
Write-Host "Protoc directory exists, skipping installation"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
New-Item -Path 'C:\protoc' -ItemType Directory
|
||||||
|
Set-Location C:\protoc
|
||||||
|
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
|
||||||
|
& 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
|
||||||
|
shell: powershell
|
||||||
|
- name: Add Protoc to PATH
|
||||||
|
run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
|
||||||
|
shell: powershell
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||||
|
cargo build --target aarch64-pc-windows-msvc
|
||||||
|
cargo test --target aarch64-pc-windows-msvc
|
||||||
|
|||||||
16
Cargo.toml
16
Cargo.toml
@@ -21,13 +21,15 @@ categories = ["database-implementations"]
|
|||||||
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
|
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=0.19.1", "features" = ["dynamodb"] }
|
lance = { "version" = "=0.19.2", "features" = [
|
||||||
lance-index = { "version" = "=0.19.1" }
|
"dynamodb",
|
||||||
lance-linalg = { "version" = "=0.19.1" }
|
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||||
lance-table = { "version" = "=0.19.1" }
|
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||||
lance-testing = { "version" = "=0.19.1" }
|
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||||
lance-datafusion = { "version" = "=0.19.1" }
|
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||||
lance-encoding = { "version" = "=0.19.1" }
|
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||||
|
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||||
|
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "52.2", optional = false }
|
arrow = { version = "52.2", optional = false }
|
||||||
arrow-array = "52.2"
|
arrow-array = "52.2"
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
[](https://blog.lancedb.com/)
|
[](https://blog.lancedb.com/)
|
||||||
[](https://discord.gg/zMM32dvNtd)
|
[](https://discord.gg/zMM32dvNtd)
|
||||||
[](https://twitter.com/lancedb)
|
[](https://twitter.com/lancedb)
|
||||||
|
[](https://gurubase.io/g/lancedb)
|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# Targets supported:
|
# Targets supported:
|
||||||
# - x86_64-pc-windows-msvc
|
# - x86_64-pc-windows-msvc
|
||||||
# - i686-pc-windows-msvc
|
# - i686-pc-windows-msvc
|
||||||
|
# - aarch64-pc-windows-msvc
|
||||||
|
|
||||||
function Prebuild-Rust {
|
function Prebuild-Rust {
|
||||||
param (
|
param (
|
||||||
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
|
|||||||
|
|
||||||
$targets = $args[0]
|
$targets = $args[0]
|
||||||
if (-not $targets) {
|
if (-not $targets) {
|
||||||
$targets = "x86_64-pc-windows-msvc"
|
$targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
|
||||||
}
|
}
|
||||||
|
|
||||||
Write-Host "Building artifacts for targets: $targets"
|
Write-Host "Building artifacts for targets: $targets"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# Targets supported:
|
# Targets supported:
|
||||||
# - x86_64-pc-windows-msvc
|
# - x86_64-pc-windows-msvc
|
||||||
# - i686-pc-windows-msvc
|
# - i686-pc-windows-msvc
|
||||||
|
# - aarch64-pc-windows-msvc
|
||||||
|
|
||||||
function Prebuild-Rust {
|
function Prebuild-Rust {
|
||||||
param (
|
param (
|
||||||
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
|
|||||||
|
|
||||||
$targets = $args[0]
|
$targets = $args[0]
|
||||||
if (-not $targets) {
|
if (-not $targets) {
|
||||||
$targets = "x86_64-pc-windows-msvc"
|
$targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
|
||||||
}
|
}
|
||||||
|
|
||||||
Write-Host "Building artifacts for targets: $targets"
|
Write-Host "Building artifacts for targets: $targets"
|
||||||
|
|||||||
57
ci/mock_openai.py
Normal file
57
ci/mock_openai.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
"""A zero-dependency mock OpenAI embeddings API endpoint for testing purposes."""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import http.server
|
||||||
|
|
||||||
|
|
||||||
|
class MockOpenAIRequestHandler(http.server.BaseHTTPRequestHandler):
|
||||||
|
def do_POST(self):
|
||||||
|
content_length = int(self.headers["Content-Length"])
|
||||||
|
post_data = self.rfile.read(content_length)
|
||||||
|
post_data = json.loads(post_data.decode("utf-8"))
|
||||||
|
# See: https://platform.openai.com/docs/api-reference/embeddings/create
|
||||||
|
|
||||||
|
if isinstance(post_data["input"], str):
|
||||||
|
num_inputs = 1
|
||||||
|
else:
|
||||||
|
num_inputs = len(post_data["input"])
|
||||||
|
|
||||||
|
model = post_data.get("model", "text-embedding-ada-002")
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for i in range(num_inputs):
|
||||||
|
data.append({
|
||||||
|
"object": "embedding",
|
||||||
|
"embedding": [0.1] * 1536,
|
||||||
|
"index": i,
|
||||||
|
})
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
"model": model,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header("Content-type", "application/json")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(json.dumps(response).encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Mock OpenAI embeddings API endpoint")
|
||||||
|
parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
|
||||||
|
args = parser.parse_args()
|
||||||
|
port = args.port
|
||||||
|
|
||||||
|
print(f"server started on port {port}. Press Ctrl-C to stop.")
|
||||||
|
print(f"To use, set OPENAI_BASE_URL=http://localhost:{port} in your environment.")
|
||||||
|
|
||||||
|
with http.server.HTTPServer(("0.0.0.0", port), MockOpenAIRequestHandler) as server:
|
||||||
|
server.serve_forever()
|
||||||
@@ -45,9 +45,9 @@ Lance supports `IVF_PQ` index type by default.
|
|||||||
Creating indexes is done via the [lancedb.Table.createIndex](../js/classes/Table.md/#createIndex) method.
|
Creating indexes is done via the [lancedb.Table.createIndex](../js/classes/Table.md/#createIndex) method.
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<--- "nodejs/examples/ann_indexes.ts:import"
|
--8<--- "nodejs/examples/ann_indexes.test.ts:import"
|
||||||
|
|
||||||
--8<-- "nodejs/examples/ann_indexes.ts:ingest"
|
--8<-- "nodejs/examples/ann_indexes.test.ts:ingest"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -140,13 +140,15 @@ There are a couple of parameters that can be used to fine-tune the search:
|
|||||||
|
|
||||||
- **limit** (default: 10): The amount of results that will be returned
|
- **limit** (default: 10): The amount of results that will be returned
|
||||||
- **nprobes** (default: 20): The number of probes used. A higher number makes search more accurate but also slower.<br/>
|
- **nprobes** (default: 20): The number of probes used. A higher number makes search more accurate but also slower.<br/>
|
||||||
Most of the time, setting nprobes to cover 5-10% of the dataset should achieve high recall with low latency.<br/>
|
Most of the time, setting nprobes to cover 5-15% of the dataset should achieve high recall with low latency.<br/>
|
||||||
e.g., for 1M vectors divided up into 256 partitions, nprobes should be set to ~20-40.<br/>
|
- _For example_, For a dataset of 1 million vectors divided into 256 partitions, `nprobes` should be set to ~20-40. This value can be adjusted to achieve the optimal balance between search latency and search quality. <br/>
|
||||||
Note: nprobes is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored.
|
|
||||||
- **refine_factor** (default: None): Refine the results by reading extra elements and re-ranking them in memory.<br/>
|
- **refine_factor** (default: None): Refine the results by reading extra elements and re-ranking them in memory.<br/>
|
||||||
A higher number makes search more accurate but also slower. If you find the recall is less than ideal, try refine_factor=10 to start.<br/>
|
A higher number makes search more accurate but also slower. If you find the recall is less than ideal, try refine_factor=10 to start.<br/>
|
||||||
e.g., for 1M vectors divided into 256 partitions, if you're looking for top 20, then refine_factor=200 reranks the whole partition.<br/>
|
- _For example_, For a dataset of 1 million vectors divided into 256 partitions, setting the `refine_factor` to 200 will initially retrieve the top 4,000 candidates (top k * refine_factor) from all searched partitions. These candidates are then reranked to determine the final top 20 results.<br/>
|
||||||
Note: refine_factor is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored.
|
!!! note
|
||||||
|
Both `nprobes` and `refine_factor` are only applicable if an ANN index is present. If specified on a table without an ANN index, those parameters are ignored.
|
||||||
|
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
@@ -169,7 +171,7 @@ There are a couple of parameters that can be used to fine-tune the search:
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/ann_indexes.ts:search1"
|
--8<-- "nodejs/examples/ann_indexes.test.ts:search1"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -203,7 +205,7 @@ You can further filter the elements returned by a search using a where clause.
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/ann_indexes.ts:search2"
|
--8<-- "nodejs/examples/ann_indexes.test.ts:search2"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -235,7 +237,7 @@ You can select the columns returned by the query using a select clause.
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/ann_indexes.ts:search3"
|
--8<-- "nodejs/examples/ann_indexes.test.ts:search3"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ recommend switching to stable releases.
|
|||||||
import * as lancedb from "@lancedb/lancedb";
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
import * as arrow from "apache-arrow";
|
import * as arrow from "apache-arrow";
|
||||||
|
|
||||||
--8<-- "nodejs/examples/basic.ts:connect"
|
--8<-- "nodejs/examples/basic.test.ts:connect"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -212,7 +212,7 @@ table.
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:create_table"
|
--8<-- "nodejs/examples/basic.test.ts:create_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -268,7 +268,7 @@ similar to a `CREATE TABLE` statement in SQL.
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:create_empty_table"
|
--8<-- "nodejs/examples/basic.test.ts:create_empty_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -298,7 +298,7 @@ Once created, you can open a table as follows:
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:open_table"
|
--8<-- "nodejs/examples/basic.test.ts:open_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -327,7 +327,7 @@ If you forget the name of your table, you can always get a listing of all table
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:table_names"
|
--8<-- "nodejs/examples/basic.test.ts:table_names"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -357,7 +357,7 @@ After a table has been created, you can always add more data to it as follows:
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:add_data"
|
--8<-- "nodejs/examples/basic.test.ts:add_data"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -389,7 +389,7 @@ Once you've embedded the query, you can find its nearest neighbors as follows:
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:vector_search"
|
--8<-- "nodejs/examples/basic.test.ts:vector_search"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -429,7 +429,7 @@ LanceDB allows you to create an ANN index on a table as follows:
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:create_index"
|
--8<-- "nodejs/examples/basic.test.ts:create_index"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -469,7 +469,7 @@ This can delete any number of rows that match the filter.
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:delete_rows"
|
--8<-- "nodejs/examples/basic.test.ts:delete_rows"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -527,7 +527,7 @@ Use the `drop_table()` method on the database to remove a table.
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:drop_table"
|
--8<-- "nodejs/examples/basic.test.ts:drop_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -561,8 +561,8 @@ You can use the embedding API when working with embedding models. It automatical
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/embedding.ts:imports"
|
--8<-- "nodejs/examples/embedding.test.ts:imports"
|
||||||
--8<-- "nodejs/examples/embedding.ts:openai_embeddings"
|
--8<-- "nodejs/examples/embedding.test.ts:openai_embeddings"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Rust"
|
=== "Rust"
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
# VoyageAI Embeddings
|
||||||
|
|
||||||
|
Voyage AI provides cutting-edge embedding and rerankers.
|
||||||
|
|
||||||
|
|
||||||
|
Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
|
||||||
|
You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.
|
||||||
|
|
||||||
|
Supported models are:
|
||||||
|
|
||||||
|
- voyage-3
|
||||||
|
- voyage-3-lite
|
||||||
|
- voyage-finance-2
|
||||||
|
- voyage-multilingual-2
|
||||||
|
- voyage-law-2
|
||||||
|
- voyage-code-2
|
||||||
|
|
||||||
|
|
||||||
|
Supported parameters (to be passed in `create` method) are:
|
||||||
|
|
||||||
|
| Parameter | Type | Default Value | Description |
|
||||||
|
|---|---|--------|---------|
|
||||||
|
| `name` | `str` | `"voyage-3"` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
|
||||||
|
| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. |
|
||||||
|
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
|
||||||
|
|
||||||
|
|
||||||
|
Usage Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import lancedb
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
|
|
||||||
|
voyageai = EmbeddingFunctionRegistry
|
||||||
|
.get_instance()
|
||||||
|
.get("voyageai")
|
||||||
|
.create(name="voyage-3")
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = voyageai.SourceField()
|
||||||
|
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||||
|
|
||||||
|
data = [ { "text": "hello world" },
|
||||||
|
{ "text": "goodbye world" }]
|
||||||
|
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(data)
|
||||||
|
```
|
||||||
@@ -47,9 +47,9 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
|
|||||||
=== "TypeScript"
|
=== "TypeScript"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<--- "nodejs/examples/custom_embedding_function.ts:imports"
|
--8<--- "nodejs/examples/custom_embedding_function.test.ts:imports"
|
||||||
|
|
||||||
--8<--- "nodejs/examples/custom_embedding_function.ts:embedding_impl"
|
--8<--- "nodejs/examples/custom_embedding_function.test.ts:embedding_impl"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ Now you can use this embedding function to create your table schema and that's i
|
|||||||
=== "TypeScript"
|
=== "TypeScript"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<--- "nodejs/examples/custom_embedding_function.ts:call_custom_function"
|
--8<--- "nodejs/examples/custom_embedding_function.test.ts:call_custom_function"
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ the embeddings at all:
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/embedding.ts:imports"
|
--8<-- "nodejs/examples/embedding.test.ts:imports"
|
||||||
--8<-- "nodejs/examples/embedding.ts:embedding_function"
|
--8<-- "nodejs/examples/embedding.test.ts:embedding_function"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -150,7 +150,7 @@ need to worry about it when you query the table:
|
|||||||
.toArray()
|
.toArray()
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)
|
=== "vectordb (deprecated)"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
const results = await table
|
const results = await table
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ LanceDB registers the OpenAI embeddings function in the registry as `openai`. Yo
|
|||||||
=== "TypeScript"
|
=== "TypeScript"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<--- "nodejs/examples/embedding.ts:imports"
|
--8<--- "nodejs/examples/embedding.test.ts:imports"
|
||||||
--8<--- "nodejs/examples/embedding.ts:openai_embeddings"
|
--8<--- "nodejs/examples/embedding.test.ts:openai_embeddings"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Rust"
|
=== "Rust"
|
||||||
@@ -121,12 +121,10 @@ class Words(LanceModel):
|
|||||||
vector: Vector(func.ndims()) = func.VectorField()
|
vector: Vector(func.ndims()) = func.VectorField()
|
||||||
|
|
||||||
table = db.create_table("words", schema=Words)
|
table = db.create_table("words", schema=Words)
|
||||||
table.add(
|
table.add([
|
||||||
[
|
|
||||||
{"text": "hello world"},
|
{"text": "hello world"},
|
||||||
{"text": "goodbye world"}
|
{"text": "goodbye world"}
|
||||||
]
|
])
|
||||||
)
|
|
||||||
|
|
||||||
query = "greetings"
|
query = "greetings"
|
||||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||||
|
|||||||
@@ -85,13 +85,13 @@ Initialize a LanceDB connection and create a table
|
|||||||
|
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/basic.ts:create_table"
|
--8<-- "nodejs/examples/basic.test.ts:create_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
This will infer the schema from the provided data. If you want to explicitly provide a schema, you can use `apache-arrow` to declare a schema
|
This will infer the schema from the provided data. If you want to explicitly provide a schema, you can use `apache-arrow` to declare a schema
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/basic.ts:create_table_with_schema"
|
--8<-- "nodejs/examples/basic.test.ts:create_table_with_schema"
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! info "Note"
|
!!! info "Note"
|
||||||
@@ -100,14 +100,14 @@ Initialize a LanceDB connection and create a table
|
|||||||
passed in will NOT be appended to the table in that case.
|
passed in will NOT be appended to the table in that case.
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/basic.ts:create_table_exists_ok"
|
--8<-- "nodejs/examples/basic.test.ts:create_table_exists_ok"
|
||||||
```
|
```
|
||||||
|
|
||||||
Sometimes you want to make sure that you start fresh. If you want to
|
Sometimes you want to make sure that you start fresh. If you want to
|
||||||
overwrite the table, you can pass in mode: "overwrite" to the createTable function.
|
overwrite the table, you can pass in mode: "overwrite" to the createTable function.
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/basic.ts:create_table_overwrite"
|
--8<-- "nodejs/examples/basic.test.ts:create_table_overwrite"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -227,7 +227,7 @@ LanceDB supports float16 data type!
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:create_f16_table"
|
--8<-- "nodejs/examples/basic.test.ts:create_f16_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -455,7 +455,7 @@ You can create an empty table for scenarios where you want to add data to the ta
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
--8<-- "nodejs/examples/basic.ts:create_empty_table"
|
--8<-- "nodejs/examples/basic.test.ts:create_empty_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
77
docs/src/reranking/voyageai.md
Normal file
77
docs/src/reranking/voyageai.md
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Voyage AI Reranker
|
||||||
|
|
||||||
|
Voyage AI provides cutting-edge embedding and rerankers.
|
||||||
|
|
||||||
|
This re-ranker uses the [VoyageAI](https://docs.voyageai.com/docs/) API to rerank the search results. You can use this re-ranker by passing `VoyageAIReranker()` to the `rerank()` method. Note that you'll either need to set the `VOYAGE_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker.
|
||||||
|
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
Supported Query Types: Hybrid, Vector, FTS
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
import numpy
|
||||||
|
import lancedb
|
||||||
|
from lancedb.embeddings import get_registry
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.rerankers import VoyageAIReranker
|
||||||
|
|
||||||
|
embedder = get_registry().get("sentence-transformers").create()
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
|
||||||
|
class Schema(LanceModel):
|
||||||
|
text: str = embedder.SourceField()
|
||||||
|
vector: Vector(embedder.ndims()) = embedder.VectorField()
|
||||||
|
|
||||||
|
data = [
|
||||||
|
{"text": "hello world"},
|
||||||
|
{"text": "goodbye world"}
|
||||||
|
]
|
||||||
|
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||||
|
tbl.add(data)
|
||||||
|
reranker = VoyageAIReranker(model_name="rerank-2")
|
||||||
|
|
||||||
|
# Run vector search with a reranker
|
||||||
|
result = tbl.search("hello").rerank(reranker=reranker).to_list()
|
||||||
|
|
||||||
|
# Run FTS search with a reranker
|
||||||
|
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()
|
||||||
|
|
||||||
|
# Run hybrid search with a reranker
|
||||||
|
tbl.create_fts_index("text", replace=True)
|
||||||
|
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Accepted Arguments
|
||||||
|
----------------
|
||||||
|
| Argument | Type | Default | Description |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `model_name` | `str` | `None` | The name of the reranker model to use. Available models are: rerank-2, rerank-2-lite |
|
||||||
|
| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. |
|
||||||
|
| `top_n` | `str` | `None` | The number of results to return. If None, will return all results. |
|
||||||
|
| `api_key` | `str` | `None` | The API key for the Voyage AI API. If not provided, the `VOYAGE_API_KEY` environment variable is used. |
|
||||||
|
| `return_score` | str | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score. If "all" is supported, will return relevance score along with the vector and/or fts scores depending on query type |
|
||||||
|
| `truncation` | `bool` | `None` | Whether to truncate the input to satisfy the "context length limit" on the query and the documents. |
|
||||||
|
|
||||||
|
|
||||||
|
## Supported Scores for each query type
|
||||||
|
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:
|
||||||
|
|
||||||
|
### Hybrid Search
|
||||||
|
|`return_score`| Status | Description |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||||
|
| `all` | ❌ Not Supported | Returns have vector(`_distance`) and FTS(`score`) along with Hybrid Search score(`_relevance_score`) |
|
||||||
|
|
||||||
|
### Vector Search
|
||||||
|
|`return_score`| Status | Description |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||||
|
| `all` | ✅ Supported | Returns have vector(`_distance`) along with Hybrid Search score(`_relevance_score`) |
|
||||||
|
|
||||||
|
### FTS Search
|
||||||
|
|`return_score`| Status | Description |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||||
|
| `all` | ✅ Supported | Returns have FTS(`score`) along with Hybrid Search score(`_relevance_score`) |
|
||||||
@@ -58,9 +58,9 @@ db.create_table("my_vectors", data=data)
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/search.ts:import"
|
--8<-- "nodejs/examples/search.test.ts:import"
|
||||||
|
|
||||||
--8<-- "nodejs/examples/search.ts:search1"
|
--8<-- "nodejs/examples/search.test.ts:search1"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ By default, `l2` will be used as metric type. You can specify the metric type as
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/search.ts:search2"
|
--8<-- "nodejs/examples/search.test.ts:search2"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ const tbl = await db.createTable('myVectors', data)
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/filtering.ts:search"
|
--8<-- "nodejs/examples/filtering.test.ts:search"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -91,7 +91,7 @@ For example, the following filter string is acceptable:
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/filtering.ts:vec_search"
|
--8<-- "nodejs/examples/filtering.test.ts:vec_search"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
@@ -169,7 +169,7 @@ You can also filter your data without search.
|
|||||||
=== "@lancedb/lancedb"
|
=== "@lancedb/lancedb"
|
||||||
|
|
||||||
```ts
|
```ts
|
||||||
--8<-- "nodejs/examples/filtering.ts:sql_search"
|
--8<-- "nodejs/examples/filtering.test.ts:sql_search"
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "vectordb (deprecated)"
|
=== "vectordb (deprecated)"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.11.1-beta.1</version>
|
<version>0.13.0-beta.1</version>
|
||||||
<relativePath>../pom.xml</relativePath>
|
<relativePath>../pom.xml</relativePath>
|
||||||
</parent>
|
</parent>
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.11.1-beta.1</version>
|
<version>0.13.0-beta.1</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
|
|
||||||
<name>LanceDB Parent</name>
|
<name>LanceDB Parent</name>
|
||||||
|
|||||||
50
node/package-lock.json
generated
50
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
@@ -52,11 +52,12 @@
|
|||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0"
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
|
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
|
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
|
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
|
"@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1",
|
||||||
|
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1"
|
||||||
},
|
},
|
||||||
"peerDependencies": {
|
"peerDependencies": {
|
||||||
"@apache-arrow/ts": "^14.0.2",
|
"@apache-arrow/ts": "^14.0.2",
|
||||||
@@ -327,65 +328,60 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.11.1-beta.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.13.0-beta.1.tgz",
|
||||||
"integrity": "sha512-q9jcCbmcz45UHmjgecL6zK82WaqUJsARfniwXXPcnd8ooISVhPkgN+RVKv6edwI9T0PV+xVRYq+LQLlZu5fyxw==",
|
"integrity": "sha512-beOrf6selCzzhLgDG8Nibma4nO/CSnA1wUKRmlJHEPtGcg7PW18z6MP/nfwQMpMR/FLRfTo8pPTbpzss47MiQQ==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
"license": "Apache-2.0",
|
|
||||||
"optional": true,
|
"optional": true,
|
||||||
"os": [
|
"os": [
|
||||||
"darwin"
|
"darwin"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.11.1-beta.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.13.0-beta.1.tgz",
|
||||||
"integrity": "sha512-E5tCTS5TaTkssTPa+gdnFxZJ1f60jnSIJXhqufNFZk4s+IMViwR1BPqaqE++WY5c1uBI55ef1862CROKDKX4gg==",
|
"integrity": "sha512-YdraGRF/RbJRkKh0v3xT03LUhq47T2GtCvJ5gZp8wKlh4pHa8LuhLU0DIdvmG/DT5vuQA+td8HDkBm/e3EOdNg==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
"license": "Apache-2.0",
|
|
||||||
"optional": true,
|
"optional": true,
|
||||||
"os": [
|
"os": [
|
||||||
"darwin"
|
"darwin"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.11.1-beta.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.13.0-beta.1.tgz",
|
||||||
"integrity": "sha512-Obohy6TH31Uq+fp6ZisHR7iAsvgVPqBExrycVcIJqrLZnIe88N9OWUwBXkmfMAw/2hNJFwD4tU7+4U2FcBWX4w==",
|
"integrity": "sha512-Pp0O/uhEqof1oLaWrNbv+Ym+q8kBkiCqaA5+2eAZ6a3e9U+Ozkvb0FQrHuyi9adJ5wKQ4NabyQE9BMf2bYpOnQ==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
"license": "Apache-2.0",
|
|
||||||
"optional": true,
|
"optional": true,
|
||||||
"os": [
|
"os": [
|
||||||
"linux"
|
"linux"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.11.1-beta.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.13.0-beta.1.tgz",
|
||||||
"integrity": "sha512-3Meu0dgrzNrnBVVQhxkUSAOhQNmgtKHvOvmrRLUicV+X19hd33udihgxVpZZb9mpXenJ8lZsS+Jq6R0hWqntag==",
|
"integrity": "sha512-y8nxOye4egfWF5FGED9EfkmZ1O5HnRLU4a61B8m5JSpkivO9v2epTcbYN0yt/7ZFCgtqMfJ8VW4Mi7qQcz3KDA==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
"license": "Apache-2.0",
|
|
||||||
"optional": true,
|
"optional": true,
|
||||||
"os": [
|
"os": [
|
||||||
"linux"
|
"linux"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.11.1-beta.1.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.13.0-beta.1.tgz",
|
||||||
"integrity": "sha512-BafZ9OJPQXsS7JW0weAl12wC+827AiRjfUrE5tvrYWZah2OwCF2U2g6uJ3x4pxfwEGsv5xcHFqgxlS7ttFkh+Q==",
|
"integrity": "sha512-STMDP9dp0TBLkB3ro+16pKcGy6bmbhRuEZZZ1Tp5P75yTPeVh4zIgWkidMdU1qBbEYM7xacnsp9QAwgLnMU/Ow==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
"license": "Apache-2.0",
|
|
||||||
"optional": true,
|
"optional": true,
|
||||||
"os": [
|
"os": [
|
||||||
"win32"
|
"win32"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"description": " Serverless, low-latency vector database for AI applications",
|
"description": " Serverless, low-latency vector database for AI applications",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "dist/index.d.ts",
|
"types": "dist/index.d.ts",
|
||||||
@@ -84,14 +84,16 @@
|
|||||||
"aarch64-apple-darwin": "@lancedb/vectordb-darwin-arm64",
|
"aarch64-apple-darwin": "@lancedb/vectordb-darwin-arm64",
|
||||||
"x86_64-unknown-linux-gnu": "@lancedb/vectordb-linux-x64-gnu",
|
"x86_64-unknown-linux-gnu": "@lancedb/vectordb-linux-x64-gnu",
|
||||||
"aarch64-unknown-linux-gnu": "@lancedb/vectordb-linux-arm64-gnu",
|
"aarch64-unknown-linux-gnu": "@lancedb/vectordb-linux-arm64-gnu",
|
||||||
"x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc"
|
"x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc",
|
||||||
|
"aarch64-pc-windows-msvc": "@lancedb/vectordb-win32-arm64-msvc"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
|
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
|
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
|
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
|
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1",
|
||||||
|
"@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import axios, { type AxiosResponse, type ResponseType } from 'axios'
|
import axios, { type AxiosError, type AxiosResponse, type ResponseType } from 'axios'
|
||||||
|
|
||||||
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
|
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
|
||||||
|
|
||||||
@@ -197,7 +197,7 @@ export class HttpLancedbClient {
|
|||||||
response = await callWithMiddlewares(req, this._middlewares)
|
response = await callWithMiddlewares(req, this._middlewares)
|
||||||
return response
|
return response
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
console.error('error: ', err)
|
console.error(serializeErrorAsJson(err))
|
||||||
if (err.response === undefined) {
|
if (err.response === undefined) {
|
||||||
throw new Error(`Network Error: ${err.message as string}`)
|
throw new Error(`Network Error: ${err.message as string}`)
|
||||||
}
|
}
|
||||||
@@ -247,7 +247,8 @@ export class HttpLancedbClient {
|
|||||||
|
|
||||||
// return response
|
// return response
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
console.error('error: ', err)
|
console.error(serializeErrorAsJson(err))
|
||||||
|
|
||||||
if (err.response === undefined) {
|
if (err.response === undefined) {
|
||||||
throw new Error(`Network Error: ${err.message as string}`)
|
throw new Error(`Network Error: ${err.message as string}`)
|
||||||
}
|
}
|
||||||
@@ -287,3 +288,15 @@ export class HttpLancedbClient {
|
|||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function serializeErrorAsJson(err: AxiosError) {
|
||||||
|
const error = JSON.parse(JSON.stringify(err, Object.getOwnPropertyNames(err)))
|
||||||
|
error.response = err.response != null
|
||||||
|
? JSON.parse(JSON.stringify(
|
||||||
|
err.response,
|
||||||
|
// config contains the request data, too noisy
|
||||||
|
Object.getOwnPropertyNames(err.response).filter(prop => prop !== 'config')
|
||||||
|
))
|
||||||
|
: null
|
||||||
|
return JSON.stringify({ error })
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
version = "0.11.1-beta.1"
|
version = "0.13.0-beta.1"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
description.workspace = true
|
description.workspace = true
|
||||||
repository.workspace = true
|
repository.workspace = true
|
||||||
@@ -18,7 +18,7 @@ futures.workspace = true
|
|||||||
lancedb = { path = "../rust/lancedb", features = ["remote"] }
|
lancedb = { path = "../rust/lancedb", features = ["remote"] }
|
||||||
napi = { version = "2.16.8", default-features = false, features = [
|
napi = { version = "2.16.8", default-features = false, features = [
|
||||||
"napi9",
|
"napi9",
|
||||||
"async",
|
"async"
|
||||||
] }
|
] }
|
||||||
napi-derive = "2.16.4"
|
napi-derive = "2.16.4"
|
||||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||||
|
|||||||
@@ -402,6 +402,40 @@ describe("When creating an index", () => {
|
|||||||
expect(rst.numRows).toBe(1);
|
expect(rst.numRows).toBe(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should be able to query unindexed data", async () => {
|
||||||
|
await tbl.createIndex("vec");
|
||||||
|
await tbl.add([
|
||||||
|
{
|
||||||
|
id: 300,
|
||||||
|
vec: Array(32)
|
||||||
|
.fill(1)
|
||||||
|
.map(() => Math.random()),
|
||||||
|
tags: [],
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const plan1 = await tbl.query().nearestTo(queryVec).explainPlan(true);
|
||||||
|
expect(plan1).toMatch("LanceScan");
|
||||||
|
|
||||||
|
const plan2 = await tbl
|
||||||
|
.query()
|
||||||
|
.nearestTo(queryVec)
|
||||||
|
.fastSearch()
|
||||||
|
.explainPlan(true);
|
||||||
|
expect(plan2).not.toMatch("LanceScan");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should be able to query with row id", async () => {
|
||||||
|
const results = await tbl
|
||||||
|
.query()
|
||||||
|
.nearestTo(queryVec)
|
||||||
|
.withRowId()
|
||||||
|
.limit(1)
|
||||||
|
.toArray();
|
||||||
|
expect(results.length).toBe(1);
|
||||||
|
expect(results[0]).toHaveProperty("_rowid");
|
||||||
|
});
|
||||||
|
|
||||||
it("should allow parameters to be specified", async () => {
|
it("should allow parameters to be specified", async () => {
|
||||||
await tbl.createIndex("vec", {
|
await tbl.createIndex("vec", {
|
||||||
config: Index.ivfPq({
|
config: Index.ivfPq({
|
||||||
@@ -964,4 +998,18 @@ describe("column name options", () => {
|
|||||||
const results = await table.query().where("`camelCase` = 1").toArray();
|
const results = await table.query().where("`camelCase` = 1").toArray();
|
||||||
expect(results[0].camelCase).toBe(1);
|
expect(results[0].camelCase).toBe(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("can make multiple vector queries in one go", async () => {
|
||||||
|
const results = await table
|
||||||
|
.query()
|
||||||
|
.nearestTo([0.1, 0.2])
|
||||||
|
.addQueryVector([0.1, 0.2])
|
||||||
|
.limit(1)
|
||||||
|
.toArray();
|
||||||
|
console.log(results);
|
||||||
|
expect(results.length).toBe(2);
|
||||||
|
results.sort((a, b) => a.query_index - b.query_index);
|
||||||
|
expect(results[0].query_index).toBe(0);
|
||||||
|
expect(results[1].query_index).toBe(1);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -9,7 +9,8 @@
|
|||||||
"**/native.js",
|
"**/native.js",
|
||||||
"**/native.d.ts",
|
"**/native.d.ts",
|
||||||
"**/npm/**/*",
|
"**/npm/**/*",
|
||||||
"**/.vscode/**"
|
"**/.vscode/**",
|
||||||
|
"./examples/*"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"formatter": {
|
"formatter": {
|
||||||
|
|||||||
57
nodejs/examples/ann_indexes.test.ts
Normal file
57
nodejs/examples/ann_indexes.test.ts
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
// --8<-- [start:import]
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
import { VectorQuery } from "@lancedb/lancedb";
|
||||||
|
// --8<-- [end:import]
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
test("ann index examples", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
// --8<-- [start:ingest]
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
|
||||||
|
const data = Array.from({ length: 5_000 }, (_, i) => ({
|
||||||
|
vector: Array(128).fill(i),
|
||||||
|
id: `${i}`,
|
||||||
|
content: "",
|
||||||
|
longId: `${i}`,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const table = await db.createTable("my_vectors", data, {
|
||||||
|
mode: "overwrite",
|
||||||
|
});
|
||||||
|
await table.createIndex("vector", {
|
||||||
|
config: lancedb.Index.ivfPq({
|
||||||
|
numPartitions: 10,
|
||||||
|
numSubVectors: 16,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
// --8<-- [end:ingest]
|
||||||
|
|
||||||
|
// --8<-- [start:search1]
|
||||||
|
const search = table.search(Array(128).fill(1.2)).limit(2) as VectorQuery;
|
||||||
|
const results1 = await search.nprobes(20).refineFactor(10).toArray();
|
||||||
|
// --8<-- [end:search1]
|
||||||
|
expect(results1.length).toBe(2);
|
||||||
|
|
||||||
|
// --8<-- [start:search2]
|
||||||
|
const results2 = await table
|
||||||
|
.search(Array(128).fill(1.2))
|
||||||
|
.where("id != '1141'")
|
||||||
|
.limit(2)
|
||||||
|
.toArray();
|
||||||
|
// --8<-- [end:search2]
|
||||||
|
expect(results2.length).toBe(2);
|
||||||
|
|
||||||
|
// --8<-- [start:search3]
|
||||||
|
const results3 = await table
|
||||||
|
.search(Array(128).fill(1.2))
|
||||||
|
.select(["id"])
|
||||||
|
.limit(2)
|
||||||
|
.toArray();
|
||||||
|
// --8<-- [end:search3]
|
||||||
|
expect(results3.length).toBe(2);
|
||||||
|
});
|
||||||
|
}, 100_000);
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
// --8<-- [start:import]
|
|
||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
// --8<-- [end:import]
|
|
||||||
|
|
||||||
// --8<-- [start:ingest]
|
|
||||||
const db = await lancedb.connect("/tmp/lancedb/");
|
|
||||||
|
|
||||||
const data = Array.from({ length: 10_000 }, (_, i) => ({
|
|
||||||
vector: Array(1536).fill(i),
|
|
||||||
id: `${i}`,
|
|
||||||
content: "",
|
|
||||||
longId: `${i}`,
|
|
||||||
}));
|
|
||||||
|
|
||||||
const table = await db.createTable("my_vectors", data, { mode: "overwrite" });
|
|
||||||
await table.createIndex("vector", {
|
|
||||||
config: lancedb.Index.ivfPq({
|
|
||||||
numPartitions: 16,
|
|
||||||
numSubVectors: 48,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
// --8<-- [end:ingest]
|
|
||||||
|
|
||||||
// --8<-- [start:search1]
|
|
||||||
const _results1 = await table
|
|
||||||
.search(Array(1536).fill(1.2))
|
|
||||||
.limit(2)
|
|
||||||
.nprobes(20)
|
|
||||||
.refineFactor(10)
|
|
||||||
.toArray();
|
|
||||||
// --8<-- [end:search1]
|
|
||||||
|
|
||||||
// --8<-- [start:search2]
|
|
||||||
const _results2 = await table
|
|
||||||
.search(Array(1536).fill(1.2))
|
|
||||||
.where("id != '1141'")
|
|
||||||
.limit(2)
|
|
||||||
.toArray();
|
|
||||||
// --8<-- [end:search2]
|
|
||||||
|
|
||||||
// --8<-- [start:search3]
|
|
||||||
const _results3 = await table
|
|
||||||
.search(Array(1536).fill(1.2))
|
|
||||||
.select(["id"])
|
|
||||||
.limit(2)
|
|
||||||
.toArray();
|
|
||||||
// --8<-- [end:search3]
|
|
||||||
|
|
||||||
console.log("Ann indexes: done");
|
|
||||||
175
nodejs/examples/basic.test.ts
Normal file
175
nodejs/examples/basic.test.ts
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
// --8<-- [start:imports]
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
import * as arrow from "apache-arrow";
|
||||||
|
import {
|
||||||
|
Field,
|
||||||
|
FixedSizeList,
|
||||||
|
Float16,
|
||||||
|
Int32,
|
||||||
|
Schema,
|
||||||
|
Utf8,
|
||||||
|
} from "apache-arrow";
|
||||||
|
// --8<-- [end:imports]
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
test("basic table examples", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
// --8<-- [start:connect]
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
// --8<-- [end:connect]
|
||||||
|
{
|
||||||
|
// --8<-- [start:create_table]
|
||||||
|
const _tbl = await db.createTable(
|
||||||
|
"myTable",
|
||||||
|
[
|
||||||
|
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
|
||||||
|
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
|
||||||
|
],
|
||||||
|
{ mode: "overwrite" },
|
||||||
|
);
|
||||||
|
// --8<-- [end:create_table]
|
||||||
|
|
||||||
|
const data = [
|
||||||
|
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
|
||||||
|
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
|
||||||
|
];
|
||||||
|
|
||||||
|
{
|
||||||
|
// --8<-- [start:create_table_exists_ok]
|
||||||
|
const tbl = await db.createTable("myTable", data, {
|
||||||
|
existOk: true,
|
||||||
|
});
|
||||||
|
// --8<-- [end:create_table_exists_ok]
|
||||||
|
expect(await tbl.countRows()).toBe(2);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// --8<-- [start:create_table_overwrite]
|
||||||
|
const tbl = await db.createTable("myTable", data, {
|
||||||
|
mode: "overwrite",
|
||||||
|
});
|
||||||
|
// --8<-- [end:create_table_overwrite]
|
||||||
|
expect(await tbl.countRows()).toBe(2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await db.dropTable("myTable");
|
||||||
|
|
||||||
|
{
|
||||||
|
// --8<-- [start:create_table_with_schema]
|
||||||
|
const schema = new arrow.Schema([
|
||||||
|
new arrow.Field(
|
||||||
|
"vector",
|
||||||
|
new arrow.FixedSizeList(
|
||||||
|
2,
|
||||||
|
new arrow.Field("item", new arrow.Float32(), true),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
new arrow.Field("item", new arrow.Utf8(), true),
|
||||||
|
new arrow.Field("price", new arrow.Float32(), true),
|
||||||
|
]);
|
||||||
|
const data = [
|
||||||
|
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
|
||||||
|
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
|
||||||
|
];
|
||||||
|
const tbl = await db.createTable("myTable", data, {
|
||||||
|
schema,
|
||||||
|
});
|
||||||
|
// --8<-- [end:create_table_with_schema]
|
||||||
|
expect(await tbl.countRows()).toBe(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// --8<-- [start:create_empty_table]
|
||||||
|
|
||||||
|
const schema = new arrow.Schema([
|
||||||
|
new arrow.Field("id", new arrow.Int32()),
|
||||||
|
new arrow.Field("name", new arrow.Utf8()),
|
||||||
|
]);
|
||||||
|
|
||||||
|
const emptyTbl = await db.createEmptyTable("empty_table", schema);
|
||||||
|
// --8<-- [end:create_empty_table]
|
||||||
|
expect(await emptyTbl.countRows()).toBe(0);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// --8<-- [start:open_table]
|
||||||
|
const _tbl = await db.openTable("myTable");
|
||||||
|
// --8<-- [end:open_table]
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// --8<-- [start:table_names]
|
||||||
|
const tableNames = await db.tableNames();
|
||||||
|
// --8<-- [end:table_names]
|
||||||
|
expect(tableNames).toEqual(["empty_table", "myTable"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const tbl = await db.openTable("myTable");
|
||||||
|
{
|
||||||
|
// --8<-- [start:add_data]
|
||||||
|
const data = [
|
||||||
|
{ vector: [1.3, 1.4], item: "fizz", price: 100.0 },
|
||||||
|
{ vector: [9.5, 56.2], item: "buzz", price: 200.0 },
|
||||||
|
];
|
||||||
|
await tbl.add(data);
|
||||||
|
// --8<-- [end:add_data]
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// --8<-- [start:vector_search]
|
||||||
|
const res = await tbl.search([100, 100]).limit(2).toArray();
|
||||||
|
// --8<-- [end:vector_search]
|
||||||
|
expect(res.length).toBe(2);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const data = Array.from({ length: 1000 })
|
||||||
|
.fill(null)
|
||||||
|
.map(() => ({
|
||||||
|
vector: [Math.random(), Math.random()],
|
||||||
|
item: "autogen",
|
||||||
|
price: Math.round(Math.random() * 100),
|
||||||
|
}));
|
||||||
|
|
||||||
|
await tbl.add(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --8<-- [start:create_index]
|
||||||
|
await tbl.createIndex("vector");
|
||||||
|
// --8<-- [end:create_index]
|
||||||
|
|
||||||
|
// --8<-- [start:delete_rows]
|
||||||
|
await tbl.delete('item = "fizz"');
|
||||||
|
// --8<-- [end:delete_rows]
|
||||||
|
|
||||||
|
// --8<-- [start:drop_table]
|
||||||
|
await db.dropTable("myTable");
|
||||||
|
// --8<-- [end:drop_table]
|
||||||
|
await db.dropTable("empty_table");
|
||||||
|
|
||||||
|
{
|
||||||
|
// --8<-- [start:create_f16_table]
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
const dim = 16;
|
||||||
|
const total = 10;
|
||||||
|
const f16Schema = new Schema([
|
||||||
|
new Field("id", new Int32()),
|
||||||
|
new Field(
|
||||||
|
"vector",
|
||||||
|
new FixedSizeList(dim, new Field("item", new Float16(), true)),
|
||||||
|
false,
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
const data = lancedb.makeArrowTable(
|
||||||
|
Array.from(Array(total), (_, i) => ({
|
||||||
|
id: i,
|
||||||
|
vector: Array.from(Array(dim), Math.random),
|
||||||
|
})),
|
||||||
|
{ schema: f16Schema },
|
||||||
|
);
|
||||||
|
const _table = await db.createTable("f16_tbl", data);
|
||||||
|
// --8<-- [end:create_f16_table]
|
||||||
|
await db.dropTable("f16_tbl");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
// --8<-- [start:imports]
|
|
||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
import * as arrow from "apache-arrow";
|
|
||||||
import {
|
|
||||||
Field,
|
|
||||||
FixedSizeList,
|
|
||||||
Float16,
|
|
||||||
Int32,
|
|
||||||
Schema,
|
|
||||||
Utf8,
|
|
||||||
} from "apache-arrow";
|
|
||||||
|
|
||||||
// --8<-- [end:imports]
|
|
||||||
|
|
||||||
// --8<-- [start:connect]
|
|
||||||
const uri = "/tmp/lancedb/";
|
|
||||||
const db = await lancedb.connect(uri);
|
|
||||||
// --8<-- [end:connect]
|
|
||||||
{
|
|
||||||
// --8<-- [start:create_table]
|
|
||||||
const tbl = await db.createTable(
|
|
||||||
"myTable",
|
|
||||||
[
|
|
||||||
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
|
|
||||||
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
|
|
||||||
],
|
|
||||||
{ mode: "overwrite" },
|
|
||||||
);
|
|
||||||
// --8<-- [end:create_table]
|
|
||||||
|
|
||||||
const data = [
|
|
||||||
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
|
|
||||||
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
|
|
||||||
];
|
|
||||||
|
|
||||||
{
|
|
||||||
// --8<-- [start:create_table_exists_ok]
|
|
||||||
const tbl = await db.createTable("myTable", data, {
|
|
||||||
existsOk: true,
|
|
||||||
});
|
|
||||||
// --8<-- [end:create_table_exists_ok]
|
|
||||||
}
|
|
||||||
{
|
|
||||||
// --8<-- [start:create_table_overwrite]
|
|
||||||
const _tbl = await db.createTable("myTable", data, {
|
|
||||||
mode: "overwrite",
|
|
||||||
});
|
|
||||||
// --8<-- [end:create_table_overwrite]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// --8<-- [start:create_table_with_schema]
|
|
||||||
const schema = new arrow.Schema([
|
|
||||||
new arrow.Field(
|
|
||||||
"vector",
|
|
||||||
new arrow.FixedSizeList(
|
|
||||||
2,
|
|
||||||
new arrow.Field("item", new arrow.Float32(), true),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
new arrow.Field("item", new arrow.Utf8(), true),
|
|
||||||
new arrow.Field("price", new arrow.Float32(), true),
|
|
||||||
]);
|
|
||||||
const data = [
|
|
||||||
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
|
|
||||||
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
|
|
||||||
];
|
|
||||||
const _tbl = await db.createTable("myTable", data, {
|
|
||||||
schema,
|
|
||||||
});
|
|
||||||
// --8<-- [end:create_table_with_schema]
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// --8<-- [start:create_empty_table]
|
|
||||||
|
|
||||||
const schema = new arrow.Schema([
|
|
||||||
new arrow.Field("id", new arrow.Int32()),
|
|
||||||
new arrow.Field("name", new arrow.Utf8()),
|
|
||||||
]);
|
|
||||||
|
|
||||||
const empty_tbl = await db.createEmptyTable("empty_table", schema);
|
|
||||||
// --8<-- [end:create_empty_table]
|
|
||||||
}
|
|
||||||
{
|
|
||||||
// --8<-- [start:open_table]
|
|
||||||
const _tbl = await db.openTable("myTable");
|
|
||||||
// --8<-- [end:open_table]
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// --8<-- [start:table_names]
|
|
||||||
const tableNames = await db.tableNames();
|
|
||||||
console.log(tableNames);
|
|
||||||
// --8<-- [end:table_names]
|
|
||||||
}
|
|
||||||
|
|
||||||
const tbl = await db.openTable("myTable");
|
|
||||||
{
|
|
||||||
// --8<-- [start:add_data]
|
|
||||||
const data = [
|
|
||||||
{ vector: [1.3, 1.4], item: "fizz", price: 100.0 },
|
|
||||||
{ vector: [9.5, 56.2], item: "buzz", price: 200.0 },
|
|
||||||
];
|
|
||||||
await tbl.add(data);
|
|
||||||
// --8<-- [end:add_data]
|
|
||||||
}
|
|
||||||
{
|
|
||||||
// --8<-- [start:vector_search]
|
|
||||||
const _res = tbl.search([100, 100]).limit(2).toArray();
|
|
||||||
// --8<-- [end:vector_search]
|
|
||||||
}
|
|
||||||
{
|
|
||||||
const data = Array.from({ length: 1000 })
|
|
||||||
.fill(null)
|
|
||||||
.map(() => ({
|
|
||||||
vector: [Math.random(), Math.random()],
|
|
||||||
item: "autogen",
|
|
||||||
price: Math.round(Math.random() * 100),
|
|
||||||
}));
|
|
||||||
|
|
||||||
await tbl.add(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
// --8<-- [start:create_index]
|
|
||||||
await tbl.createIndex("vector");
|
|
||||||
// --8<-- [end:create_index]
|
|
||||||
|
|
||||||
// --8<-- [start:delete_rows]
|
|
||||||
await tbl.delete('item = "fizz"');
|
|
||||||
// --8<-- [end:delete_rows]
|
|
||||||
|
|
||||||
// --8<-- [start:drop_table]
|
|
||||||
await db.dropTable("myTable");
|
|
||||||
// --8<-- [end:drop_table]
|
|
||||||
await db.dropTable("empty_table");
|
|
||||||
|
|
||||||
{
|
|
||||||
// --8<-- [start:create_f16_table]
|
|
||||||
const db = await lancedb.connect("/tmp/lancedb");
|
|
||||||
const dim = 16;
|
|
||||||
const total = 10;
|
|
||||||
const f16Schema = new Schema([
|
|
||||||
new Field("id", new Int32()),
|
|
||||||
new Field(
|
|
||||||
"vector",
|
|
||||||
new FixedSizeList(dim, new Field("item", new Float16(), true)),
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
]);
|
|
||||||
const data = lancedb.makeArrowTable(
|
|
||||||
Array.from(Array(total), (_, i) => ({
|
|
||||||
id: i,
|
|
||||||
vector: Array.from(Array(dim), Math.random),
|
|
||||||
})),
|
|
||||||
{ schema: f16Schema },
|
|
||||||
);
|
|
||||||
const _table = await db.createTable("f16_tbl", data);
|
|
||||||
// --8<-- [end:create_f16_table]
|
|
||||||
await db.dropTable("f16_tbl");
|
|
||||||
}
|
|
||||||
76
nodejs/examples/custom_embedding_function.test.ts
Normal file
76
nodejs/examples/custom_embedding_function.test.ts
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import { FeatureExtractionPipeline, pipeline } from "@huggingface/transformers";
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
// --8<-- [start:imports]
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
import {
|
||||||
|
LanceSchema,
|
||||||
|
TextEmbeddingFunction,
|
||||||
|
getRegistry,
|
||||||
|
register,
|
||||||
|
} from "@lancedb/lancedb/embedding";
|
||||||
|
// --8<-- [end:imports]
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
// --8<-- [start:embedding_impl]
|
||||||
|
@register("sentence-transformers")
|
||||||
|
class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
|
||||||
|
name = "Xenova/all-miniLM-L6-v2";
|
||||||
|
#ndims!: number;
|
||||||
|
extractor!: FeatureExtractionPipeline;
|
||||||
|
|
||||||
|
async init() {
|
||||||
|
this.extractor = await pipeline("feature-extraction", this.name, {
|
||||||
|
dtype: "fp32",
|
||||||
|
});
|
||||||
|
this.#ndims = await this.generateEmbeddings(["hello"]).then(
|
||||||
|
(e) => e[0].length,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
ndims() {
|
||||||
|
return this.#ndims;
|
||||||
|
}
|
||||||
|
|
||||||
|
toJSON() {
|
||||||
|
return {
|
||||||
|
name: this.name,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
async generateEmbeddings(texts: string[]) {
|
||||||
|
const output = await this.extractor(texts, {
|
||||||
|
pooling: "mean",
|
||||||
|
normalize: true,
|
||||||
|
});
|
||||||
|
return output.tolist();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -8<-- [end:embedding_impl]
|
||||||
|
|
||||||
|
test("Registry examples", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
// --8<-- [start:call_custom_function]
|
||||||
|
const registry = getRegistry();
|
||||||
|
|
||||||
|
const sentenceTransformer = await registry
|
||||||
|
.get<SentenceTransformersEmbeddings>("sentence-transformers")!
|
||||||
|
.create();
|
||||||
|
|
||||||
|
const schema = LanceSchema({
|
||||||
|
vector: sentenceTransformer.vectorField(),
|
||||||
|
text: sentenceTransformer.sourceField(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
const table = await db.createEmptyTable("table", schema, {
|
||||||
|
mode: "overwrite",
|
||||||
|
});
|
||||||
|
|
||||||
|
await table.add([{ text: "hello" }, { text: "world" }]);
|
||||||
|
|
||||||
|
const results = await table.search("greeting").limit(1).toArray();
|
||||||
|
// -8<-- [end:call_custom_function]
|
||||||
|
expect(results.length).toBe(1);
|
||||||
|
});
|
||||||
|
}, 100_000);
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
// --8<-- [start:imports]
|
|
||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
import {
|
|
||||||
LanceSchema,
|
|
||||||
TextEmbeddingFunction,
|
|
||||||
getRegistry,
|
|
||||||
register,
|
|
||||||
} from "@lancedb/lancedb/embedding";
|
|
||||||
import { pipeline } from "@xenova/transformers";
|
|
||||||
// --8<-- [end:imports]
|
|
||||||
|
|
||||||
// --8<-- [start:embedding_impl]
|
|
||||||
@register("sentence-transformers")
|
|
||||||
class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
|
|
||||||
name = "Xenova/all-miniLM-L6-v2";
|
|
||||||
#ndims!: number;
|
|
||||||
extractor: any;
|
|
||||||
|
|
||||||
async init() {
|
|
||||||
this.extractor = await pipeline("feature-extraction", this.name);
|
|
||||||
this.#ndims = await this.generateEmbeddings(["hello"]).then(
|
|
||||||
(e) => e[0].length,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
ndims() {
|
|
||||||
return this.#ndims;
|
|
||||||
}
|
|
||||||
|
|
||||||
toJSON() {
|
|
||||||
return {
|
|
||||||
name: this.name,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
async generateEmbeddings(texts: string[]) {
|
|
||||||
const output = await this.extractor(texts, {
|
|
||||||
pooling: "mean",
|
|
||||||
normalize: true,
|
|
||||||
});
|
|
||||||
return output.tolist();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// -8<-- [end:embedding_impl]
|
|
||||||
|
|
||||||
// --8<-- [start:call_custom_function]
|
|
||||||
const registry = getRegistry();
|
|
||||||
|
|
||||||
const sentenceTransformer = await registry
|
|
||||||
.get<SentenceTransformersEmbeddings>("sentence-transformers")!
|
|
||||||
.create();
|
|
||||||
|
|
||||||
const schema = LanceSchema({
|
|
||||||
vector: sentenceTransformer.vectorField(),
|
|
||||||
text: sentenceTransformer.sourceField(),
|
|
||||||
});
|
|
||||||
|
|
||||||
const db = await lancedb.connect("/tmp/db");
|
|
||||||
const table = await db.createEmptyTable("table", schema, { mode: "overwrite" });
|
|
||||||
|
|
||||||
await table.add([{ text: "hello" }, { text: "world" }]);
|
|
||||||
|
|
||||||
const results = await table.search("greeting").limit(1).toArray();
|
|
||||||
console.log(results[0].text);
|
|
||||||
// -8<-- [end:call_custom_function]
|
|
||||||
96
nodejs/examples/embedding.test.ts
Normal file
96
nodejs/examples/embedding.test.ts
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
// --8<-- [start:imports]
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
import "@lancedb/lancedb/embedding/openai";
|
||||||
|
import { LanceSchema, getRegistry, register } from "@lancedb/lancedb/embedding";
|
||||||
|
import { EmbeddingFunction } from "@lancedb/lancedb/embedding";
|
||||||
|
import { type Float, Float32, Utf8 } from "apache-arrow";
|
||||||
|
// --8<-- [end:imports]
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
const openAiTest = process.env.OPENAI_API_KEY == null ? test.skip : test;
|
||||||
|
|
||||||
|
openAiTest("openai embeddings", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
// --8<-- [start:openai_embeddings]
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
const func = getRegistry()
|
||||||
|
.get("openai")
|
||||||
|
?.create({ model: "text-embedding-ada-002" }) as EmbeddingFunction;
|
||||||
|
|
||||||
|
const wordsSchema = LanceSchema({
|
||||||
|
text: func.sourceField(new Utf8()),
|
||||||
|
vector: func.vectorField(),
|
||||||
|
});
|
||||||
|
const tbl = await db.createEmptyTable("words", wordsSchema, {
|
||||||
|
mode: "overwrite",
|
||||||
|
});
|
||||||
|
await tbl.add([{ text: "hello world" }, { text: "goodbye world" }]);
|
||||||
|
|
||||||
|
const query = "greetings";
|
||||||
|
const actual = (await tbl.search(query).limit(1).toArray())[0];
|
||||||
|
// --8<-- [end:openai_embeddings]
|
||||||
|
expect(actual).toHaveProperty("text");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("custom embedding function", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
// --8<-- [start:embedding_function]
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
|
||||||
|
@register("my_embedding")
|
||||||
|
class MyEmbeddingFunction extends EmbeddingFunction<string> {
|
||||||
|
toJSON(): object {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
ndims() {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
embeddingDataType(): Float {
|
||||||
|
return new Float32();
|
||||||
|
}
|
||||||
|
async computeQueryEmbeddings(_data: string) {
|
||||||
|
// This is a placeholder for a real embedding function
|
||||||
|
return [1, 2, 3];
|
||||||
|
}
|
||||||
|
async computeSourceEmbeddings(data: string[]) {
|
||||||
|
// This is a placeholder for a real embedding function
|
||||||
|
return Array.from({ length: data.length }).fill([
|
||||||
|
1, 2, 3,
|
||||||
|
]) as number[][];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const func = new MyEmbeddingFunction();
|
||||||
|
|
||||||
|
const data = [{ text: "pepperoni" }, { text: "pineapple" }];
|
||||||
|
|
||||||
|
// Option 1: manually specify the embedding function
|
||||||
|
const table = await db.createTable("vectors", data, {
|
||||||
|
embeddingFunction: {
|
||||||
|
function: func,
|
||||||
|
sourceColumn: "text",
|
||||||
|
vectorColumn: "vector",
|
||||||
|
},
|
||||||
|
mode: "overwrite",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Option 2: provide the embedding function through a schema
|
||||||
|
|
||||||
|
const schema = LanceSchema({
|
||||||
|
text: func.sourceField(new Utf8()),
|
||||||
|
vector: func.vectorField(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const table2 = await db.createTable("vectors2", data, {
|
||||||
|
schema,
|
||||||
|
mode: "overwrite",
|
||||||
|
});
|
||||||
|
// --8<-- [end:embedding_function]
|
||||||
|
expect(await table.countRows()).toBe(2);
|
||||||
|
expect(await table2.countRows()).toBe(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
// --8<-- [start:imports]
|
|
||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
import { LanceSchema, getRegistry, register } from "@lancedb/lancedb/embedding";
|
|
||||||
import { EmbeddingFunction } from "@lancedb/lancedb/embedding";
|
|
||||||
import { type Float, Float32, Utf8 } from "apache-arrow";
|
|
||||||
// --8<-- [end:imports]
|
|
||||||
|
|
||||||
{
|
|
||||||
// --8<-- [start:openai_embeddings]
|
|
||||||
|
|
||||||
const db = await lancedb.connect("/tmp/db");
|
|
||||||
const func = getRegistry()
|
|
||||||
.get("openai")
|
|
||||||
?.create({ model: "text-embedding-ada-002" }) as EmbeddingFunction;
|
|
||||||
|
|
||||||
const wordsSchema = LanceSchema({
|
|
||||||
text: func.sourceField(new Utf8()),
|
|
||||||
vector: func.vectorField(),
|
|
||||||
});
|
|
||||||
const tbl = await db.createEmptyTable("words", wordsSchema, {
|
|
||||||
mode: "overwrite",
|
|
||||||
});
|
|
||||||
await tbl.add([{ text: "hello world" }, { text: "goodbye world" }]);
|
|
||||||
|
|
||||||
const query = "greetings";
|
|
||||||
const actual = (await (await tbl.search(query)).limit(1).toArray())[0];
|
|
||||||
|
|
||||||
// --8<-- [end:openai_embeddings]
|
|
||||||
console.log("result = ", actual.text);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// --8<-- [start:embedding_function]
|
|
||||||
const db = await lancedb.connect("/tmp/db");
|
|
||||||
|
|
||||||
@register("my_embedding")
|
|
||||||
class MyEmbeddingFunction extends EmbeddingFunction<string> {
|
|
||||||
toJSON(): object {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
ndims() {
|
|
||||||
return 3;
|
|
||||||
}
|
|
||||||
embeddingDataType(): Float {
|
|
||||||
return new Float32();
|
|
||||||
}
|
|
||||||
async computeQueryEmbeddings(_data: string) {
|
|
||||||
// This is a placeholder for a real embedding function
|
|
||||||
return [1, 2, 3];
|
|
||||||
}
|
|
||||||
async computeSourceEmbeddings(data: string[]) {
|
|
||||||
// This is a placeholder for a real embedding function
|
|
||||||
return Array.from({ length: data.length }).fill([1, 2, 3]) as number[][];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const func = new MyEmbeddingFunction();
|
|
||||||
|
|
||||||
const data = [{ text: "pepperoni" }, { text: "pineapple" }];
|
|
||||||
|
|
||||||
// Option 1: manually specify the embedding function
|
|
||||||
const table = await db.createTable("vectors", data, {
|
|
||||||
embeddingFunction: {
|
|
||||||
function: func,
|
|
||||||
sourceColumn: "text",
|
|
||||||
vectorColumn: "vector",
|
|
||||||
},
|
|
||||||
mode: "overwrite",
|
|
||||||
});
|
|
||||||
|
|
||||||
// Option 2: provide the embedding function through a schema
|
|
||||||
|
|
||||||
const schema = LanceSchema({
|
|
||||||
text: func.sourceField(new Utf8()),
|
|
||||||
vector: func.vectorField(),
|
|
||||||
});
|
|
||||||
|
|
||||||
const table2 = await db.createTable("vectors2", data, {
|
|
||||||
schema,
|
|
||||||
mode: "overwrite",
|
|
||||||
});
|
|
||||||
// --8<-- [end:embedding_function]
|
|
||||||
}
|
|
||||||
42
nodejs/examples/filtering.test.ts
Normal file
42
nodejs/examples/filtering.test.ts
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
test("filtering examples", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
|
||||||
|
const data = Array.from({ length: 10_000 }, (_, i) => ({
|
||||||
|
vector: Array(1536).fill(i),
|
||||||
|
id: i,
|
||||||
|
item: `item ${i}`,
|
||||||
|
strId: `${i}`,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const tbl = await db.createTable("myVectors", data, { mode: "overwrite" });
|
||||||
|
|
||||||
|
// --8<-- [start:search]
|
||||||
|
const _result = await tbl
|
||||||
|
.search(Array(1536).fill(0.5))
|
||||||
|
.limit(1)
|
||||||
|
.where("id = 10")
|
||||||
|
.toArray();
|
||||||
|
// --8<-- [end:search]
|
||||||
|
|
||||||
|
// --8<-- [start:vec_search]
|
||||||
|
const result = await (
|
||||||
|
tbl.search(Array(1536).fill(0)) as lancedb.VectorQuery
|
||||||
|
)
|
||||||
|
.where("(item IN ('item 0', 'item 2')) AND (id > 10)")
|
||||||
|
.postfilter()
|
||||||
|
.toArray();
|
||||||
|
// --8<-- [end:vec_search]
|
||||||
|
expect(result.length).toBe(0);
|
||||||
|
|
||||||
|
// --8<-- [start:sql_search]
|
||||||
|
await tbl.query().where("id = 10").limit(10).toArray();
|
||||||
|
// --8<-- [end:sql_search]
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
|
|
||||||
const db = await lancedb.connect("data/sample-lancedb");
|
|
||||||
|
|
||||||
const data = Array.from({ length: 10_000 }, (_, i) => ({
|
|
||||||
vector: Array(1536).fill(i),
|
|
||||||
id: i,
|
|
||||||
item: `item ${i}`,
|
|
||||||
strId: `${i}`,
|
|
||||||
}));
|
|
||||||
|
|
||||||
const tbl = await db.createTable("myVectors", data, { mode: "overwrite" });
|
|
||||||
|
|
||||||
// --8<-- [start:search]
|
|
||||||
const _result = await tbl
|
|
||||||
.search(Array(1536).fill(0.5))
|
|
||||||
.limit(1)
|
|
||||||
.where("id = 10")
|
|
||||||
.toArray();
|
|
||||||
// --8<-- [end:search]
|
|
||||||
|
|
||||||
// --8<-- [start:vec_search]
|
|
||||||
await tbl
|
|
||||||
.search(Array(1536).fill(0))
|
|
||||||
.where("(item IN ('item 0', 'item 2')) AND (id > 10)")
|
|
||||||
.postfilter()
|
|
||||||
.toArray();
|
|
||||||
// --8<-- [end:vec_search]
|
|
||||||
|
|
||||||
// --8<-- [start:sql_search]
|
|
||||||
await tbl.query().where("id = 10").limit(10).toArray();
|
|
||||||
// --8<-- [end:sql_search]
|
|
||||||
|
|
||||||
console.log("SQL search: done");
|
|
||||||
45
nodejs/examples/full_text_search.test.ts
Normal file
45
nodejs/examples/full_text_search.test.ts
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
test("full text search", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
|
||||||
|
const words = [
|
||||||
|
"apple",
|
||||||
|
"banana",
|
||||||
|
"cherry",
|
||||||
|
"date",
|
||||||
|
"elderberry",
|
||||||
|
"fig",
|
||||||
|
"grape",
|
||||||
|
];
|
||||||
|
|
||||||
|
const data = Array.from({ length: 10_000 }, (_, i) => ({
|
||||||
|
vector: Array(1536).fill(i),
|
||||||
|
id: i,
|
||||||
|
item: `item ${i}`,
|
||||||
|
strId: `${i}`,
|
||||||
|
doc: words[i % words.length],
|
||||||
|
}));
|
||||||
|
|
||||||
|
const tbl = await db.createTable("myVectors", data, { mode: "overwrite" });
|
||||||
|
|
||||||
|
await tbl.createIndex("doc", {
|
||||||
|
config: lancedb.Index.fts(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// --8<-- [start:full_text_search]
|
||||||
|
const result = await tbl
|
||||||
|
.query()
|
||||||
|
.nearestToText("apple")
|
||||||
|
.select(["id", "doc"])
|
||||||
|
.limit(10)
|
||||||
|
.toArray();
|
||||||
|
expect(result.length).toBe(10);
|
||||||
|
// --8<-- [end:full_text_search]
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
// Copyright 2024 Lance Developers.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
|
|
||||||
const db = await lancedb.connect("data/sample-lancedb");
|
|
||||||
|
|
||||||
const words = [
|
|
||||||
"apple",
|
|
||||||
"banana",
|
|
||||||
"cherry",
|
|
||||||
"date",
|
|
||||||
"elderberry",
|
|
||||||
"fig",
|
|
||||||
"grape",
|
|
||||||
];
|
|
||||||
|
|
||||||
const data = Array.from({ length: 10_000 }, (_, i) => ({
|
|
||||||
vector: Array(1536).fill(i),
|
|
||||||
id: i,
|
|
||||||
item: `item ${i}`,
|
|
||||||
strId: `${i}`,
|
|
||||||
doc: words[i % words.length],
|
|
||||||
}));
|
|
||||||
|
|
||||||
const tbl = await db.createTable("myVectors", data, { mode: "overwrite" });
|
|
||||||
|
|
||||||
await tbl.createIndex("doc", {
|
|
||||||
config: lancedb.Index.fts(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// --8<-- [start:full_text_search]
|
|
||||||
let result = await tbl
|
|
||||||
.search("apple")
|
|
||||||
.select(["id", "doc"])
|
|
||||||
.limit(10)
|
|
||||||
.toArray();
|
|
||||||
console.log(result);
|
|
||||||
// --8<-- [end:full_text_search]
|
|
||||||
|
|
||||||
console.log("SQL search: done");
|
|
||||||
6
nodejs/examples/jest.config.cjs
Normal file
6
nodejs/examples/jest.config.cjs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||||
|
module.exports = {
|
||||||
|
preset: "ts-jest",
|
||||||
|
testEnvironment: "node",
|
||||||
|
testPathIgnorePatterns: ["./dist"],
|
||||||
|
};
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
{
|
|
||||||
"compilerOptions": {
|
|
||||||
// Enable latest features
|
|
||||||
"lib": ["ESNext", "DOM"],
|
|
||||||
"target": "ESNext",
|
|
||||||
"module": "ESNext",
|
|
||||||
"moduleDetection": "force",
|
|
||||||
"jsx": "react-jsx",
|
|
||||||
"allowJs": true,
|
|
||||||
|
|
||||||
// Bundler mode
|
|
||||||
"moduleResolution": "bundler",
|
|
||||||
"allowImportingTsExtensions": true,
|
|
||||||
"verbatimModuleSyntax": true,
|
|
||||||
"noEmit": true,
|
|
||||||
|
|
||||||
// Best practices
|
|
||||||
"strict": true,
|
|
||||||
"skipLibCheck": true,
|
|
||||||
"noFallthroughCasesInSwitch": true,
|
|
||||||
|
|
||||||
// Some stricter flags (disabled by default)
|
|
||||||
"noUnusedLocals": false,
|
|
||||||
"noUnusedParameters": false,
|
|
||||||
"noPropertyAccessFromIndexSignature": false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
4991
nodejs/examples/package-lock.json
generated
4991
nodejs/examples/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -5,24 +5,29 @@
|
|||||||
"main": "index.js",
|
"main": "index.js",
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"test": "echo \"Error: no test specified\" && exit 1"
|
"//1": "--experimental-vm-modules is needed to run jest with sentence-transformers",
|
||||||
|
"//2": "--testEnvironment is needed to run jest with sentence-transformers",
|
||||||
|
"//3": "See: https://github.com/huggingface/transformers.js/issues/57",
|
||||||
|
"test": "node --experimental-vm-modules node_modules/.bin/jest --testEnvironment jest-environment-node-single-context --verbose",
|
||||||
|
"lint": "biome check *.ts && biome format *.ts",
|
||||||
|
"lint-ci": "biome ci .",
|
||||||
|
"lint-fix": "biome check --write *.ts && npm run format",
|
||||||
|
"format": "biome format --write *.ts"
|
||||||
},
|
},
|
||||||
"author": "Lance Devs",
|
"author": "Lance Devs",
|
||||||
"license": "Apache-2.0",
|
"license": "Apache-2.0",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@lancedb/lancedb": "file:../",
|
"@huggingface/transformers": "^3.0.2",
|
||||||
"@xenova/transformers": "^2.17.2"
|
"@lancedb/lancedb": "file:../dist",
|
||||||
|
"openai": "^4.29.2",
|
||||||
|
"sharp": "^0.33.5"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
"@biomejs/biome": "^1.7.3",
|
||||||
|
"@jest/globals": "^29.7.0",
|
||||||
|
"jest": "^29.7.0",
|
||||||
|
"jest-environment-node-single-context": "^29.4.0",
|
||||||
|
"ts-jest": "^29.2.5",
|
||||||
"typescript": "^5.5.4"
|
"typescript": "^5.5.4"
|
||||||
},
|
|
||||||
"compilerOptions": {
|
|
||||||
"target": "ESNext",
|
|
||||||
"module": "ESNext",
|
|
||||||
"moduleResolution": "Node",
|
|
||||||
"strict": true,
|
|
||||||
"esModuleInterop": true,
|
|
||||||
"skipLibCheck": true,
|
|
||||||
"forceConsistentCasingInFileNames": true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
42
nodejs/examples/search.test.ts
Normal file
42
nodejs/examples/search.test.ts
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
// --8<-- [start:import]
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
// --8<-- [end:import]
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
test("full text search", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
{
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
|
||||||
|
const data = Array.from({ length: 10_000 }, (_, i) => ({
|
||||||
|
vector: Array(128).fill(i),
|
||||||
|
id: `${i}`,
|
||||||
|
content: "",
|
||||||
|
longId: `${i}`,
|
||||||
|
}));
|
||||||
|
|
||||||
|
await db.createTable("my_vectors", data);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --8<-- [start:search1]
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
const tbl = await db.openTable("my_vectors");
|
||||||
|
|
||||||
|
const results1 = await tbl.search(Array(128).fill(1.2)).limit(10).toArray();
|
||||||
|
// --8<-- [end:search1]
|
||||||
|
expect(results1.length).toBe(10);
|
||||||
|
|
||||||
|
// --8<-- [start:search2]
|
||||||
|
const results2 = await (
|
||||||
|
tbl.search(Array(128).fill(1.2)) as lancedb.VectorQuery
|
||||||
|
)
|
||||||
|
.distanceType("cosine")
|
||||||
|
.limit(10)
|
||||||
|
.toArray();
|
||||||
|
// --8<-- [end:search2]
|
||||||
|
expect(results2.length).toBe(10);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
// --8<-- [end:import]
|
|
||||||
import * as fs from "node:fs";
|
|
||||||
// --8<-- [start:import]
|
|
||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
|
|
||||||
async function setup() {
|
|
||||||
fs.rmSync("data/sample-lancedb", { recursive: true, force: true });
|
|
||||||
const db = await lancedb.connect("data/sample-lancedb");
|
|
||||||
|
|
||||||
const data = Array.from({ length: 10_000 }, (_, i) => ({
|
|
||||||
vector: Array(1536).fill(i),
|
|
||||||
id: `${i}`,
|
|
||||||
content: "",
|
|
||||||
longId: `${i}`,
|
|
||||||
}));
|
|
||||||
|
|
||||||
await db.createTable("my_vectors", data);
|
|
||||||
}
|
|
||||||
|
|
||||||
await setup();
|
|
||||||
|
|
||||||
// --8<-- [start:search1]
|
|
||||||
const db = await lancedb.connect("data/sample-lancedb");
|
|
||||||
const tbl = await db.openTable("my_vectors");
|
|
||||||
|
|
||||||
const _results1 = await tbl.search(Array(1536).fill(1.2)).limit(10).toArray();
|
|
||||||
// --8<-- [end:search1]
|
|
||||||
|
|
||||||
// --8<-- [start:search2]
|
|
||||||
const _results2 = await tbl
|
|
||||||
.search(Array(1536).fill(1.2))
|
|
||||||
.distanceType("cosine")
|
|
||||||
.limit(10)
|
|
||||||
.toArray();
|
|
||||||
console.log(_results2);
|
|
||||||
// --8<-- [end:search2]
|
|
||||||
|
|
||||||
console.log("search: done");
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
import * as lancedb from "@lancedb/lancedb";
|
|
||||||
|
|
||||||
import { LanceSchema, getRegistry } from "@lancedb/lancedb/embedding";
|
|
||||||
import { Utf8 } from "apache-arrow";
|
|
||||||
|
|
||||||
const db = await lancedb.connect("/tmp/db");
|
|
||||||
const func = await getRegistry().get("huggingface").create();
|
|
||||||
|
|
||||||
const facts = [
|
|
||||||
"Albert Einstein was a theoretical physicist.",
|
|
||||||
"The capital of France is Paris.",
|
|
||||||
"The Great Wall of China is one of the Seven Wonders of the World.",
|
|
||||||
"Python is a popular programming language.",
|
|
||||||
"Mount Everest is the highest mountain in the world.",
|
|
||||||
"Leonardo da Vinci painted the Mona Lisa.",
|
|
||||||
"Shakespeare wrote Hamlet.",
|
|
||||||
"The human body has 206 bones.",
|
|
||||||
"The speed of light is approximately 299,792 kilometers per second.",
|
|
||||||
"Water boils at 100 degrees Celsius.",
|
|
||||||
"The Earth orbits the Sun.",
|
|
||||||
"The Pyramids of Giza are located in Egypt.",
|
|
||||||
"Coffee is one of the most popular beverages in the world.",
|
|
||||||
"Tokyo is the capital city of Japan.",
|
|
||||||
"Photosynthesis is the process by which plants make their food.",
|
|
||||||
"The Pacific Ocean is the largest ocean on Earth.",
|
|
||||||
"Mozart was a prolific composer of classical music.",
|
|
||||||
"The Internet is a global network of computers.",
|
|
||||||
"Basketball is a sport played with a ball and a hoop.",
|
|
||||||
"The first computer virus was created in 1983.",
|
|
||||||
"Artificial neural networks are inspired by the human brain.",
|
|
||||||
"Deep learning is a subset of machine learning.",
|
|
||||||
"IBM's Watson won Jeopardy! in 2011.",
|
|
||||||
"The first computer programmer was Ada Lovelace.",
|
|
||||||
"The first chatbot was ELIZA, created in the 1960s.",
|
|
||||||
].map((text) => ({ text }));
|
|
||||||
|
|
||||||
const factsSchema = LanceSchema({
|
|
||||||
text: func.sourceField(new Utf8()),
|
|
||||||
vector: func.vectorField(),
|
|
||||||
});
|
|
||||||
|
|
||||||
const tbl = await db.createTable("facts", facts, {
|
|
||||||
mode: "overwrite",
|
|
||||||
schema: factsSchema,
|
|
||||||
});
|
|
||||||
|
|
||||||
const query = "How many bones are in the human body?";
|
|
||||||
const actual = await tbl.search(query).limit(1).toArray();
|
|
||||||
|
|
||||||
console.log("Answer: ", actual[0]["text"]);
|
|
||||||
59
nodejs/examples/sentence-transformers.test.ts
Normal file
59
nodejs/examples/sentence-transformers.test.ts
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import { expect, test } from "@jest/globals";
|
||||||
|
import { withTempDirectory } from "./util.ts";
|
||||||
|
|
||||||
|
import * as lancedb from "@lancedb/lancedb";
|
||||||
|
import "@lancedb/lancedb/embedding/transformers";
|
||||||
|
import { LanceSchema, getRegistry } from "@lancedb/lancedb/embedding";
|
||||||
|
import { Utf8 } from "apache-arrow";
|
||||||
|
|
||||||
|
test("full text search", async () => {
|
||||||
|
await withTempDirectory(async (databaseDir) => {
|
||||||
|
const db = await lancedb.connect(databaseDir);
|
||||||
|
const func = await getRegistry().get("huggingface").create();
|
||||||
|
|
||||||
|
const facts = [
|
||||||
|
"Albert Einstein was a theoretical physicist.",
|
||||||
|
"The capital of France is Paris.",
|
||||||
|
"The Great Wall of China is one of the Seven Wonders of the World.",
|
||||||
|
"Python is a popular programming language.",
|
||||||
|
"Mount Everest is the highest mountain in the world.",
|
||||||
|
"Leonardo da Vinci painted the Mona Lisa.",
|
||||||
|
"Shakespeare wrote Hamlet.",
|
||||||
|
"The human body has 206 bones.",
|
||||||
|
"The speed of light is approximately 299,792 kilometers per second.",
|
||||||
|
"Water boils at 100 degrees Celsius.",
|
||||||
|
"The Earth orbits the Sun.",
|
||||||
|
"The Pyramids of Giza are located in Egypt.",
|
||||||
|
"Coffee is one of the most popular beverages in the world.",
|
||||||
|
"Tokyo is the capital city of Japan.",
|
||||||
|
"Photosynthesis is the process by which plants make their food.",
|
||||||
|
"The Pacific Ocean is the largest ocean on Earth.",
|
||||||
|
"Mozart was a prolific composer of classical music.",
|
||||||
|
"The Internet is a global network of computers.",
|
||||||
|
"Basketball is a sport played with a ball and a hoop.",
|
||||||
|
"The first computer virus was created in 1983.",
|
||||||
|
"Artificial neural networks are inspired by the human brain.",
|
||||||
|
"Deep learning is a subset of machine learning.",
|
||||||
|
"IBM's Watson won Jeopardy! in 2011.",
|
||||||
|
"The first computer programmer was Ada Lovelace.",
|
||||||
|
"The first chatbot was ELIZA, created in the 1960s.",
|
||||||
|
].map((text) => ({ text }));
|
||||||
|
|
||||||
|
const factsSchema = LanceSchema({
|
||||||
|
text: func.sourceField(new Utf8()),
|
||||||
|
vector: func.vectorField(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const tbl = await db.createTable("facts", facts, {
|
||||||
|
mode: "overwrite",
|
||||||
|
schema: factsSchema,
|
||||||
|
});
|
||||||
|
|
||||||
|
const query = "How many bones are in the human body?";
|
||||||
|
const actual = await tbl.search(query).limit(1).toArray();
|
||||||
|
|
||||||
|
expect(actual[0]["text"]).toBe("The human body has 206 bones.");
|
||||||
|
});
|
||||||
|
});
|
||||||
17
nodejs/examples/tsconfig.json
Normal file
17
nodejs/examples/tsconfig.json
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"include": ["*.test.ts"],
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "es2022",
|
||||||
|
"module": "NodeNext",
|
||||||
|
"declaration": true,
|
||||||
|
"outDir": "./dist",
|
||||||
|
"strict": true,
|
||||||
|
"allowJs": true,
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
"emitDecoratorMetadata": true,
|
||||||
|
"experimentalDecorators": true,
|
||||||
|
"moduleResolution": "NodeNext",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"emitDeclarationOnly": true
|
||||||
|
}
|
||||||
|
}
|
||||||
16
nodejs/examples/util.ts
Normal file
16
nodejs/examples/util.ts
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
import * as fs from "fs";
|
||||||
|
import { tmpdir } from "os";
|
||||||
|
import * as path from "path";
|
||||||
|
|
||||||
|
export async function withTempDirectory(
|
||||||
|
fn: (tempDir: string) => Promise<void>,
|
||||||
|
) {
|
||||||
|
const tmpDirPath = fs.mkdtempSync(path.join(tmpdir(), "temp-dir-"));
|
||||||
|
try {
|
||||||
|
await fn(tmpDirPath);
|
||||||
|
} finally {
|
||||||
|
fs.rmSync(tmpDirPath, { recursive: true });
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,4 +4,5 @@ module.exports = {
|
|||||||
testEnvironment: "node",
|
testEnvironment: "node",
|
||||||
moduleDirectories: ["node_modules", "./dist"],
|
moduleDirectories: ["node_modules", "./dist"],
|
||||||
moduleFileExtensions: ["js", "ts"],
|
moduleFileExtensions: ["js", "ts"],
|
||||||
|
modulePathIgnorePatterns: ["<rootDir>/examples/"],
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
|
|||||||
string,
|
string,
|
||||||
Partial<XenovaTransformerOptions>
|
Partial<XenovaTransformerOptions>
|
||||||
> {
|
> {
|
||||||
#model?: import("@xenova/transformers").PreTrainedModel;
|
#model?: import("@huggingface/transformers").PreTrainedModel;
|
||||||
#tokenizer?: import("@xenova/transformers").PreTrainedTokenizer;
|
#tokenizer?: import("@huggingface/transformers").PreTrainedTokenizer;
|
||||||
#modelName: XenovaTransformerOptions["model"];
|
#modelName: XenovaTransformerOptions["model"];
|
||||||
#initialized = false;
|
#initialized = false;
|
||||||
#tokenizerOptions: XenovaTransformerOptions["tokenizerOptions"];
|
#tokenizerOptions: XenovaTransformerOptions["tokenizerOptions"];
|
||||||
@@ -92,18 +92,19 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
|
|||||||
try {
|
try {
|
||||||
// SAFETY:
|
// SAFETY:
|
||||||
// since typescript transpiles `import` to `require`, we need to do this in an unsafe way
|
// since typescript transpiles `import` to `require`, we need to do this in an unsafe way
|
||||||
// We can't use `require` because `@xenova/transformers` is an ESM module
|
// We can't use `require` because `@huggingface/transformers` is an ESM module
|
||||||
// and we can't use `import` directly because typescript will transpile it to `require`.
|
// and we can't use `import` directly because typescript will transpile it to `require`.
|
||||||
// and we want to remain compatible with both ESM and CJS modules
|
// and we want to remain compatible with both ESM and CJS modules
|
||||||
// so we use `eval` to bypass typescript for this specific import.
|
// so we use `eval` to bypass typescript for this specific import.
|
||||||
transformers = await eval('import("@xenova/transformers")');
|
transformers = await eval('import("@huggingface/transformers")');
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
throw new Error(`error loading @xenova/transformers\nReason: ${e}`);
|
throw new Error(`error loading @huggingface/transformers\nReason: ${e}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
this.#model = await transformers.AutoModel.from_pretrained(
|
this.#model = await transformers.AutoModel.from_pretrained(
|
||||||
this.#modelName,
|
this.#modelName,
|
||||||
|
{ dtype: "fp32" },
|
||||||
);
|
);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
@@ -128,7 +129,8 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
|
|||||||
} else {
|
} else {
|
||||||
const config = this.#model!.config;
|
const config = this.#model!.config;
|
||||||
|
|
||||||
const ndims = config["hidden_size"];
|
// biome-ignore lint/style/useNamingConvention: we don't control this name.
|
||||||
|
const ndims = (config as unknown as { hidden_size: number }).hidden_size;
|
||||||
if (!ndims) {
|
if (!ndims) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"hidden_size not found in model config, you may need to manually specify the embedding dimensions. ",
|
"hidden_size not found in model config, you may need to manually specify the embedding dimensions. ",
|
||||||
@@ -183,7 +185,7 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
|
|||||||
}
|
}
|
||||||
|
|
||||||
const tensorDiv = (
|
const tensorDiv = (
|
||||||
src: import("@xenova/transformers").Tensor,
|
src: import("@huggingface/transformers").Tensor,
|
||||||
divBy: number,
|
divBy: number,
|
||||||
) => {
|
) => {
|
||||||
for (let i = 0; i < src.data.length; ++i) {
|
for (let i = 0; i < src.data.length; ++i) {
|
||||||
|
|||||||
@@ -239,6 +239,29 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Skip searching un-indexed data. This can make search faster, but will miss
|
||||||
|
* any data that is not yet indexed.
|
||||||
|
*
|
||||||
|
* Use {@link lancedb.Table#optimize} to index all un-indexed data.
|
||||||
|
*/
|
||||||
|
fastSearch(): this {
|
||||||
|
this.doCall((inner: NativeQueryType) => inner.fastSearch());
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to return the row id in the results.
|
||||||
|
*
|
||||||
|
* This column can be used to match results between different queries. For
|
||||||
|
* example, to match results from a full text search and a vector search in
|
||||||
|
* order to perform hybrid search.
|
||||||
|
*/
|
||||||
|
withRowId(): this {
|
||||||
|
this.doCall((inner: NativeQueryType) => inner.withRowId());
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
protected nativeExecute(
|
protected nativeExecute(
|
||||||
options?: Partial<QueryExecutionOptions>,
|
options?: Partial<QueryExecutionOptions>,
|
||||||
): Promise<NativeBatchIterator> {
|
): Promise<NativeBatchIterator> {
|
||||||
@@ -469,6 +492,42 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
|
|||||||
super.doCall((inner) => inner.bypassVectorIndex());
|
super.doCall((inner) => inner.bypassVectorIndex());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Add a query vector to the search
|
||||||
|
*
|
||||||
|
* This method can be called multiple times to add multiple query vectors
|
||||||
|
* to the search. If multiple query vectors are added, then they will be searched
|
||||||
|
* in parallel, and the results will be concatenated. A column called `query_index`
|
||||||
|
* will be added to indicate the index of the query vector that produced the result.
|
||||||
|
*
|
||||||
|
* Performance wise, this is equivalent to running multiple queries concurrently.
|
||||||
|
*/
|
||||||
|
addQueryVector(vector: IntoVector): VectorQuery {
|
||||||
|
if (vector instanceof Promise) {
|
||||||
|
const res = (async () => {
|
||||||
|
try {
|
||||||
|
const v = await vector;
|
||||||
|
const arr = Float32Array.from(v);
|
||||||
|
//
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||||
|
const value: any = this.addQueryVector(arr);
|
||||||
|
const inner = value.inner as
|
||||||
|
| NativeVectorQuery
|
||||||
|
| Promise<NativeVectorQuery>;
|
||||||
|
return inner;
|
||||||
|
} catch (e) {
|
||||||
|
return Promise.reject(e);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
return new VectorQuery(res);
|
||||||
|
} else {
|
||||||
|
super.doCall((inner) => {
|
||||||
|
inner.addQueryVector(Float32Array.from(vector));
|
||||||
|
});
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** A builder for LanceDB queries. */
|
/** A builder for LanceDB queries. */
|
||||||
@@ -548,4 +607,9 @@ export class Query extends QueryBase<NativeQuery> {
|
|||||||
return new VectorQuery(vectorQuery);
|
return new VectorQuery(vectorQuery);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nearestToText(query: string, columns?: string[]): Query {
|
||||||
|
this.doCall((inner) => inner.fullTextSearch(query, columns));
|
||||||
|
return this;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-arm64",
|
"name": "@lancedb/lancedb-darwin-arm64",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.darwin-arm64.node",
|
"main": "lancedb.darwin-arm64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-x64",
|
"name": "@lancedb/lancedb-darwin-x64",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.darwin-x64.node",
|
"main": "lancedb.darwin-x64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-gnu.node",
|
"main": "lancedb.linux-arm64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-gnu.node",
|
"main": "lancedb.linux-x64-gnu.node",
|
||||||
|
|||||||
3
nodejs/npm/win32-arm64-msvc/README.md
Normal file
3
nodejs/npm/win32-arm64-msvc/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# `@lancedb/lancedb-win32-arm64-msvc`
|
||||||
|
|
||||||
|
This is the **aarch64-pc-windows-msvc** binary for `@lancedb/lancedb`
|
||||||
18
nodejs/npm/win32-arm64-msvc/package.json
Normal file
18
nodejs/npm/win32-arm64-msvc/package.json
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
{
|
||||||
|
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||||
|
"version": "0.13.0-beta.1",
|
||||||
|
"os": [
|
||||||
|
"win32"
|
||||||
|
],
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"main": "lancedb.win32-arm64-msvc.node",
|
||||||
|
"files": [
|
||||||
|
"lancedb.win32-arm64-msvc.node"
|
||||||
|
],
|
||||||
|
"license": "Apache 2.0",
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 18"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"os": ["win32"],
|
"os": ["win32"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.win32-x64-msvc.node",
|
"main": "lancedb.win32-x64-msvc.node",
|
||||||
|
|||||||
1432
nodejs/package-lock.json
generated
1432
nodejs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@
|
|||||||
"vector database",
|
"vector database",
|
||||||
"ann"
|
"ann"
|
||||||
],
|
],
|
||||||
"version": "0.11.1-beta.1",
|
"version": "0.13.0-beta.1",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"exports": {
|
"exports": {
|
||||||
".": "./dist/index.js",
|
".": "./dist/index.js",
|
||||||
@@ -85,7 +85,7 @@
|
|||||||
"reflect-metadata": "^0.2.2"
|
"reflect-metadata": "^0.2.2"
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@xenova/transformers": ">=2.17 < 3",
|
"@huggingface/transformers": "^3.0.2",
|
||||||
"openai": "^4.29.2"
|
"openai": "^4.29.2"
|
||||||
},
|
},
|
||||||
"peerDependencies": {
|
"peerDependencies": {
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ pub struct OpenTableOptions {
|
|||||||
#[napi::module_init]
|
#[napi::module_init]
|
||||||
fn init() {
|
fn init() {
|
||||||
let env = Env::new()
|
let env = Env::new()
|
||||||
.filter_or("LANCEDB_LOG", "trace")
|
.filter_or("LANCEDB_LOG", "warn")
|
||||||
.write_style("LANCEDB_LOG_STYLE");
|
.write_style("LANCEDB_LOG_STYLE");
|
||||||
env_logger::init_from_env(env);
|
env_logger::init_from_env(env);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,16 @@ impl Query {
|
|||||||
Ok(VectorQuery { inner })
|
Ok(VectorQuery { inner })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn fast_search(&mut self) {
|
||||||
|
self.inner = self.inner.clone().fast_search();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn with_row_id(&mut self) {
|
||||||
|
self.inner = self.inner.clone().with_row_id();
|
||||||
|
}
|
||||||
|
|
||||||
#[napi(catch_unwind)]
|
#[napi(catch_unwind)]
|
||||||
pub async fn execute(
|
pub async fn execute(
|
||||||
&self,
|
&self,
|
||||||
@@ -125,6 +135,16 @@ impl VectorQuery {
|
|||||||
self.inner = self.inner.clone().column(&column);
|
self.inner = self.inner.clone().column(&column);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn add_query_vector(&mut self, vector: Float32Array) -> Result<()> {
|
||||||
|
self.inner = self
|
||||||
|
.inner
|
||||||
|
.clone()
|
||||||
|
.add_query_vector(vector.as_ref())
|
||||||
|
.default_error()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
|
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
|
||||||
let distance_type = parse_distance_type(distance_type)?;
|
let distance_type = parse_distance_type(distance_type)?;
|
||||||
@@ -183,6 +203,16 @@ impl VectorQuery {
|
|||||||
self.inner = self.inner.clone().offset(offset as usize);
|
self.inner = self.inner.clone().offset(offset as usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn fast_search(&mut self) {
|
||||||
|
self.inner = self.inner.clone().fast_search();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn with_row_id(&mut self) {
|
||||||
|
self.inner = self.inner.clone().with_row_id();
|
||||||
|
}
|
||||||
|
|
||||||
#[napi(catch_unwind)]
|
#[napi(catch_unwind)]
|
||||||
pub async fn execute(
|
pub async fn execute(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
"experimentalDecorators": true,
|
"experimentalDecorators": true,
|
||||||
"moduleResolution": "Node"
|
"moduleResolution": "Node"
|
||||||
},
|
},
|
||||||
"exclude": ["./dist/*"],
|
"exclude": ["./dist/*", "./examples/*"],
|
||||||
"typedocOptions": {
|
"typedocOptions": {
|
||||||
"entryPoints": ["lancedb/index.ts"],
|
"entryPoints": ["lancedb/index.ts"],
|
||||||
"out": "../docs/src/javascript/",
|
"out": "../docs/src/javascript/",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.15.0"
|
current_version = "0.16.0-beta.1"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.15.0"
|
version = "0.16.0-beta.1"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "Python bindings for LanceDB"
|
description = "Python bindings for LanceDB"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ name = "lancedb"
|
|||||||
# version in Cargo.toml
|
# version in Cargo.toml
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"deprecation",
|
"deprecation",
|
||||||
"pylance==0.19.1",
|
"nest-asyncio~=1.0",
|
||||||
"requests>=2.31.0",
|
"pylance==0.19.2",
|
||||||
"tqdm>=4.27.0",
|
"tqdm>=4.27.0",
|
||||||
"pydantic>=1.10",
|
"pydantic>=1.10",
|
||||||
"attrs>=21.3.0",
|
|
||||||
"packaging",
|
"packaging",
|
||||||
"cachetools",
|
|
||||||
"overrides>=0.7",
|
"overrides>=0.7",
|
||||||
]
|
]
|
||||||
description = "lancedb"
|
description = "lancedb"
|
||||||
@@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"]
|
|||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
embeddings = [
|
embeddings = [
|
||||||
|
"requests>=2.31.0",
|
||||||
"openai>=1.6.1",
|
"openai>=1.6.1",
|
||||||
"sentence-transformers",
|
"sentence-transformers",
|
||||||
"torch",
|
"torch",
|
||||||
|
|||||||
@@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any
|
|||||||
|
|
||||||
__version__ = importlib.metadata.version("lancedb")
|
__version__ = importlib.metadata.version("lancedb")
|
||||||
|
|
||||||
from lancedb.remote import ClientConfig
|
|
||||||
|
|
||||||
from ._lancedb import connect as lancedb_connect
|
from ._lancedb import connect as lancedb_connect
|
||||||
from .common import URI, sanitize_uri
|
from .common import URI, sanitize_uri
|
||||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||||
from .remote.db import RemoteDBConnection
|
from .remote import ClientConfig
|
||||||
from .schema import vector
|
from .schema import vector
|
||||||
from .table import AsyncTable
|
from .table import AsyncTable
|
||||||
|
|
||||||
@@ -37,6 +35,7 @@ def connect(
|
|||||||
host_override: Optional[str] = None,
|
host_override: Optional[str] = None,
|
||||||
read_consistency_interval: Optional[timedelta] = None,
|
read_consistency_interval: Optional[timedelta] = None,
|
||||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||||
|
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> DBConnection:
|
) -> DBConnection:
|
||||||
"""Connect to a LanceDB database.
|
"""Connect to a LanceDB database.
|
||||||
@@ -64,14 +63,10 @@ def connect(
|
|||||||
the last check, then the table will be checked for updates. Note: this
|
the last check, then the table will be checked for updates. Note: this
|
||||||
consistency only applies to read operations. Write operations are
|
consistency only applies to read operations. Write operations are
|
||||||
always consistent.
|
always consistent.
|
||||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
client_config: ClientConfig or dict, optional
|
||||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||||
If an integer, then a ThreadPoolExecutor will be created with that
|
the keys are the attributes of the ClientConfig class. If None, then the
|
||||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
default configuration is used.
|
||||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
|
||||||
executor will be used for making requests. This is for LanceDB Cloud
|
|
||||||
only and is only used when making batch requests (i.e., passing in
|
|
||||||
multiple queries to the search method at once).
|
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
@@ -94,6 +89,8 @@ def connect(
|
|||||||
conn : DBConnection
|
conn : DBConnection
|
||||||
A connection to a LanceDB database.
|
A connection to a LanceDB database.
|
||||||
"""
|
"""
|
||||||
|
from .remote.db import RemoteDBConnection
|
||||||
|
|
||||||
if isinstance(uri, str) and uri.startswith("db://"):
|
if isinstance(uri, str) and uri.startswith("db://"):
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||||
@@ -106,7 +103,9 @@ def connect(
|
|||||||
api_key,
|
api_key,
|
||||||
region,
|
region,
|
||||||
host_override,
|
host_override,
|
||||||
|
# TODO: remove this (deprecation warning downstream)
|
||||||
request_thread_pool=request_thread_pool,
|
request_thread_pool=request_thread_pool,
|
||||||
|
client_config=client_config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ class Connection(object):
|
|||||||
data_storage_version: Optional[str] = None,
|
data_storage_version: Optional[str] = None,
|
||||||
enable_v2_manifest_paths: Optional[bool] = None,
|
enable_v2_manifest_paths: Optional[bool] = None,
|
||||||
) -> Table: ...
|
) -> Table: ...
|
||||||
|
async def rename_table(self, old_name: str, new_name: str) -> None: ...
|
||||||
|
async def drop_table(self, name: str) -> None: ...
|
||||||
|
|
||||||
class Table:
|
class Table:
|
||||||
def name(self) -> str: ...
|
def name(self) -> str: ...
|
||||||
|
|||||||
@@ -817,6 +817,18 @@ class AsyncConnection(object):
|
|||||||
table = await self._inner.open_table(name, storage_options, index_cache_size)
|
table = await self._inner.open_table(name, storage_options, index_cache_size)
|
||||||
return AsyncTable(table)
|
return AsyncTable(table)
|
||||||
|
|
||||||
|
async def rename_table(self, old_name: str, new_name: str):
|
||||||
|
"""Rename a table in the database.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
old_name: str
|
||||||
|
The current name of the table.
|
||||||
|
new_name: str
|
||||||
|
The new name of the table.
|
||||||
|
"""
|
||||||
|
await self._inner.rename_table(old_name, new_name)
|
||||||
|
|
||||||
async def drop_table(self, name: str):
|
async def drop_table(self, name: str):
|
||||||
"""Drop a table from the database.
|
"""Drop a table from the database.
|
||||||
|
|
||||||
|
|||||||
@@ -27,3 +27,4 @@ from .imagebind import ImageBindEmbeddings
|
|||||||
from .utils import with_embeddings
|
from .utils import with_embeddings
|
||||||
from .jinaai import JinaEmbeddings
|
from .jinaai import JinaEmbeddings
|
||||||
from .watsonx import WatsonxEmbeddings
|
from .watsonx import WatsonxEmbeddings
|
||||||
|
from .voyageai import VoyageAIEmbeddingFunction
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import requests
|
|
||||||
import base64
|
import base64
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction):
|
|||||||
return [result["embedding"] for result in sorted_embeddings]
|
return [result["embedding"] for result in sorted_embeddings]
|
||||||
|
|
||||||
def _init_client(self):
|
def _init_client(self):
|
||||||
|
import requests
|
||||||
|
|
||||||
if JinaEmbeddings._session is None:
|
if JinaEmbeddings._session is None:
|
||||||
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
|
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
|
||||||
api_key_not_found_help("jina")
|
api_key_not_found_help("jina")
|
||||||
|
|||||||
@@ -1,15 +1,6 @@
|
|||||||
# Copyright (c) 2023. LanceDB Developers
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
#
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
@@ -170,7 +161,7 @@ def register(name):
|
|||||||
return __REGISTRY__.get_instance().register(name)
|
return __REGISTRY__.get_instance().register(name)
|
||||||
|
|
||||||
|
|
||||||
def get_registry():
|
def get_registry() -> EmbeddingFunctionRegistry:
|
||||||
"""
|
"""
|
||||||
Utility function to get the global instance of the registry
|
Utility function to get the global instance of the registry
|
||||||
|
|
||||||
|
|||||||
127
python/python/lancedb/embeddings/voyageai.py
Normal file
127
python/python/lancedb/embeddings/voyageai.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import ClassVar, List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import api_key_not_found_help, TEXT
|
||||||
|
|
||||||
|
|
||||||
|
@register("voyageai")
|
||||||
|
class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the VoyageAI API
|
||||||
|
|
||||||
|
https://docs.voyageai.com/docs/embeddings
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str
|
||||||
|
The name of the model to use. List of acceptable models:
|
||||||
|
|
||||||
|
* voyage-3
|
||||||
|
* voyage-3-lite
|
||||||
|
* voyage-finance-2
|
||||||
|
* voyage-multilingual-2
|
||||||
|
* voyage-law-2
|
||||||
|
* voyage-code-2
|
||||||
|
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
import lancedb
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
|
|
||||||
|
voyageai = EmbeddingFunctionRegistry
|
||||||
|
.get_instance()
|
||||||
|
.get("voyageai")
|
||||||
|
.create(name="voyage-3")
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = voyageai.SourceField()
|
||||||
|
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||||
|
|
||||||
|
data = [ { "text": "hello world" },
|
||||||
|
{ "text": "goodbye world" }]
|
||||||
|
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(data)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
client: ClassVar = None
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self.name == "voyage-3-lite":
|
||||||
|
return 512
|
||||||
|
elif self.name == "voyage-code-2":
|
||||||
|
return 1536
|
||||||
|
elif self.name in [
|
||||||
|
"voyage-3",
|
||||||
|
"voyage-finance-2",
|
||||||
|
"voyage-multilingual-2",
|
||||||
|
"voyage-law-2",
|
||||||
|
]:
|
||||||
|
return 1024
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Model {self.name} not supported")
|
||||||
|
|
||||||
|
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||||
|
return self.compute_source_embeddings(query, input_type="query")
|
||||||
|
|
||||||
|
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||||
|
texts = self.sanitize_input(texts)
|
||||||
|
input_type = (
|
||||||
|
kwargs.get("input_type") or "document"
|
||||||
|
) # assume source input type if not passed by `compute_query_embeddings`
|
||||||
|
return self.generate_embeddings(texts, input_type=input_type)
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
input_type: Optional[str]
|
||||||
|
|
||||||
|
truncation: Optional[bool]
|
||||||
|
"""
|
||||||
|
VoyageAIEmbeddingFunction._init_client()
|
||||||
|
rs = VoyageAIEmbeddingFunction.client.embed(
|
||||||
|
texts=texts, model=self.name, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return [emb for emb in rs.embeddings]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _init_client():
|
||||||
|
if VoyageAIEmbeddingFunction.client is None:
|
||||||
|
voyageai = attempt_import_or_raise("voyageai")
|
||||||
|
if os.environ.get("VOYAGE_API_KEY") is None:
|
||||||
|
api_key_not_found_help("voyageai")
|
||||||
|
VoyageAIEmbeddingFunction.client = voyageai.Client(
|
||||||
|
os.environ["VOYAGE_API_KEY"]
|
||||||
|
)
|
||||||
@@ -467,6 +467,8 @@ class IvfPq:
|
|||||||
|
|
||||||
The default value is 256.
|
The default value is 256.
|
||||||
"""
|
"""
|
||||||
|
if distance_type is not None:
|
||||||
|
distance_type = distance_type.lower()
|
||||||
self._inner = LanceDbIndex.ivf_pq(
|
self._inner = LanceDbIndex.ivf_pq(
|
||||||
distance_type=distance_type,
|
distance_type=distance_type,
|
||||||
num_partitions=num_partitions,
|
num_partitions=num_partitions,
|
||||||
|
|||||||
@@ -481,6 +481,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
>>> plan = table.search(query).explain_plan(True)
|
>>> plan = table.search(query).explain_plan(True)
|
||||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||||
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
|
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
|
||||||
|
GlobalLimitExec: skip=0, fetch=10
|
||||||
FilterExec: _distance@2 IS NOT NULL
|
FilterExec: _distance@2 IS NOT NULL
|
||||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
|
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
|
||||||
KNNVectorDistance: metric=l2
|
KNNVectorDistance: metric=l2
|
||||||
@@ -500,7 +501,16 @@ class LanceQueryBuilder(ABC):
|
|||||||
nearest={
|
nearest={
|
||||||
"column": self._vector_column,
|
"column": self._vector_column,
|
||||||
"q": self._query,
|
"q": self._query,
|
||||||
|
"k": self._limit,
|
||||||
|
"metric": self._metric,
|
||||||
|
"nprobes": self._nprobes,
|
||||||
|
"refine_factor": self._refine_factor,
|
||||||
},
|
},
|
||||||
|
prefilter=self._prefilter,
|
||||||
|
filter=self._str_query,
|
||||||
|
limit=self._limit,
|
||||||
|
with_row_id=self._with_row_id,
|
||||||
|
offset=self._offset,
|
||||||
).explain_plan(verbose)
|
).explain_plan(verbose)
|
||||||
|
|
||||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
||||||
@@ -933,12 +943,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
|||||||
|
|
||||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
ds = self._table.to_lance()
|
query = Query(
|
||||||
return ds.to_table(
|
|
||||||
columns=self._columns,
|
columns=self._columns,
|
||||||
filter=self._where,
|
filter=self._where,
|
||||||
limit=self._limit,
|
k=self._limit or 10,
|
||||||
|
with_row_id=self._with_row_id,
|
||||||
|
vector=[],
|
||||||
|
# not actually respected in remote query
|
||||||
|
offset=self._offset or 0,
|
||||||
)
|
)
|
||||||
|
return self._table._execute_query(query).read_all()
|
||||||
|
|
||||||
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
||||||
"""Rerank the results using the specified reranker.
|
"""Rerank the results using the specified reranker.
|
||||||
@@ -1315,6 +1329,48 @@ class AsyncQueryBase(object):
|
|||||||
self._inner.offset(offset)
|
self._inner.offset(offset)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def fast_search(self) -> AsyncQuery:
|
||||||
|
"""
|
||||||
|
Skip searching un-indexed data.
|
||||||
|
|
||||||
|
This can make queries faster, but will miss any data that has not been
|
||||||
|
indexed.
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
You can add new data into an existing index by calling
|
||||||
|
[AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
|
||||||
|
"""
|
||||||
|
self._inner.fast_search()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_row_id(self) -> AsyncQuery:
|
||||||
|
"""
|
||||||
|
Include the _rowid column in the results.
|
||||||
|
"""
|
||||||
|
self._inner.with_row_id()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def postfilter(self) -> AsyncQuery:
|
||||||
|
"""
|
||||||
|
If this is called then filtering will happen after the search instead of
|
||||||
|
before.
|
||||||
|
By default filtering will be performed before the search. This is how
|
||||||
|
filtering is typically understood to work. This prefilter step does add some
|
||||||
|
additional latency. Creating a scalar index on the filter column(s) can
|
||||||
|
often improve this latency. However, sometimes a filter is too complex or
|
||||||
|
scalar indices cannot be applied to the column. In these cases postfiltering
|
||||||
|
can be used instead of prefiltering to improve latency.
|
||||||
|
Post filtering applies the filter to the results of the search. This
|
||||||
|
means we only run the filter on a much smaller set of data. However, it can
|
||||||
|
cause the query to return fewer than `limit` results (or even no results) if
|
||||||
|
none of the nearest results match the filter.
|
||||||
|
Post filtering happens during the "refine stage" (described in more detail in
|
||||||
|
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
|
||||||
|
factor can often help restore some of the results lost by post filtering.
|
||||||
|
"""
|
||||||
|
self._inner.postfilter()
|
||||||
|
return self
|
||||||
|
|
||||||
async def to_batches(
|
async def to_batches(
|
||||||
self, *, max_batch_length: Optional[int] = None
|
self, *, max_batch_length: Optional[int] = None
|
||||||
) -> AsyncRecordBatchReader:
|
) -> AsyncRecordBatchReader:
|
||||||
@@ -1439,7 +1495,7 @@ class AsyncQuery(AsyncQueryBase):
|
|||||||
return pa.array(vec)
|
return pa.array(vec)
|
||||||
|
|
||||||
def nearest_to(
|
def nearest_to(
|
||||||
self, query_vector: Optional[Union[VEC, Tuple]] = None
|
self, query_vector: Optional[Union[VEC, Tuple, List[VEC]]] = None
|
||||||
) -> AsyncVectorQuery:
|
) -> AsyncVectorQuery:
|
||||||
"""
|
"""
|
||||||
Find the nearest vectors to the given query vector.
|
Find the nearest vectors to the given query vector.
|
||||||
@@ -1477,7 +1533,27 @@ class AsyncQuery(AsyncQueryBase):
|
|||||||
|
|
||||||
Vector searches always have a [limit][]. If `limit` has not been called then
|
Vector searches always have a [limit][]. If `limit` has not been called then
|
||||||
a default `limit` of 10 will be used.
|
a default `limit` of 10 will be used.
|
||||||
|
|
||||||
|
Typically, a single vector is passed in as the query. However, you can also
|
||||||
|
pass in multiple vectors. This can be useful if you want to find the nearest
|
||||||
|
vectors to multiple query vectors. This is not expected to be faster than
|
||||||
|
making multiple queries concurrently; it is just a convenience method.
|
||||||
|
If multiple vectors are passed in then an additional column `query_index`
|
||||||
|
will be added to the results. This column will contain the index of the
|
||||||
|
query vector that the result is nearest to.
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
isinstance(query_vector, list)
|
||||||
|
and len(query_vector) > 0
|
||||||
|
and not isinstance(query_vector[0], (float, int))
|
||||||
|
):
|
||||||
|
# multiple have been passed
|
||||||
|
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
|
||||||
|
new_self = self._inner.nearest_to(query_vectors[0])
|
||||||
|
for v in query_vectors[1:]:
|
||||||
|
new_self.add_query_vector(v)
|
||||||
|
return AsyncVectorQuery(new_self)
|
||||||
|
else:
|
||||||
return AsyncVectorQuery(
|
return AsyncVectorQuery(
|
||||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||||
)
|
)
|
||||||
@@ -1618,30 +1694,6 @@ class AsyncVectorQuery(AsyncQueryBase):
|
|||||||
self._inner.distance_type(distance_type)
|
self._inner.distance_type(distance_type)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def postfilter(self) -> AsyncVectorQuery:
|
|
||||||
"""
|
|
||||||
If this is called then filtering will happen after the vector search instead of
|
|
||||||
before.
|
|
||||||
|
|
||||||
By default filtering will be performed before the vector search. This is how
|
|
||||||
filtering is typically understood to work. This prefilter step does add some
|
|
||||||
additional latency. Creating a scalar index on the filter column(s) can
|
|
||||||
often improve this latency. However, sometimes a filter is too complex or
|
|
||||||
scalar indices cannot be applied to the column. In these cases postfiltering
|
|
||||||
can be used instead of prefiltering to improve latency.
|
|
||||||
|
|
||||||
Post filtering applies the filter to the results of the vector search. This
|
|
||||||
means we only run the filter on a much smaller set of data. However, it can
|
|
||||||
cause the query to return fewer than `limit` results (or even no results) if
|
|
||||||
none of the nearest results match the filter.
|
|
||||||
|
|
||||||
Post filtering happens during the "refine stage" (described in more detail in
|
|
||||||
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
|
|
||||||
factor can often help restore some of the results lost by post filtering.
|
|
||||||
"""
|
|
||||||
self._inner.postfilter()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def bypass_vector_index(self) -> AsyncVectorQuery:
|
def bypass_vector_index(self) -> AsyncVectorQuery:
|
||||||
"""
|
"""
|
||||||
If this is called then any vector index is skipped
|
If this is called then any vector index is skipped
|
||||||
|
|||||||
@@ -11,62 +11,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
from dataclasses import dataclass, field
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import attrs
|
|
||||||
from lancedb import __version__
|
from lancedb import __version__
|
||||||
import pyarrow as pa
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from lancedb.common import VECTOR_COLUMN_NAME
|
__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"]
|
||||||
|
|
||||||
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
|
|
||||||
|
|
||||||
|
|
||||||
class VectorQuery(BaseModel):
|
|
||||||
# vector to search for
|
|
||||||
vector: List[float]
|
|
||||||
|
|
||||||
# sql filter to refine the query with
|
|
||||||
filter: Optional[str] = None
|
|
||||||
|
|
||||||
# top k results to return
|
|
||||||
k: int
|
|
||||||
|
|
||||||
# # metrics
|
|
||||||
_metric: str = "L2"
|
|
||||||
|
|
||||||
# which columns to return in the results
|
|
||||||
columns: Optional[List[str]] = None
|
|
||||||
|
|
||||||
# optional query parameters for tuning the results,
|
|
||||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
|
||||||
nprobes: int = 10
|
|
||||||
|
|
||||||
refine_factor: Optional[int] = None
|
|
||||||
|
|
||||||
vector_column: str = VECTOR_COLUMN_NAME
|
|
||||||
|
|
||||||
fast_search: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@attrs.define
|
|
||||||
class VectorQueryResult:
|
|
||||||
# for now the response is directly seralized into a pandas dataframe
|
|
||||||
tbl: pa.Table
|
|
||||||
|
|
||||||
def to_arrow(self) -> pa.Table:
|
|
||||||
return self.tbl
|
|
||||||
|
|
||||||
|
|
||||||
class LanceDBClient(abc.ABC):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
|
||||||
"""Query the LanceDB server for the given table and query."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -165,8 +116,8 @@ class RetryConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ClientConfig:
|
class ClientConfig:
|
||||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||||
retry_config: Optional[RetryConfig] = None
|
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||||
timeout_config: Optional[TimeoutConfig] = None
|
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if isinstance(self.retry_config, dict):
|
if isinstance(self.retry_config, dict):
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
# Copyright 2023 LanceDB Developers
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Iterable, Union
|
|
||||||
import pyarrow as pa
|
|
||||||
|
|
||||||
|
|
||||||
def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes:
|
|
||||||
"""Serialize a PyArrow Table to IPC binary."""
|
|
||||||
sink = pa.BufferOutputStream()
|
|
||||||
if isinstance(table, Iterable):
|
|
||||||
table = pa.Table.from_batches(table)
|
|
||||||
with pa.ipc.new_stream(sink, table.schema) as writer:
|
|
||||||
writer.write_table(table)
|
|
||||||
return sink.getvalue().to_pybytes()
|
|
||||||
@@ -1,269 +0,0 @@
|
|||||||
# Copyright 2023 LanceDB Developers
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import attrs
|
|
||||||
import pyarrow as pa
|
|
||||||
import requests
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from requests.adapters import HTTPAdapter
|
|
||||||
from urllib3 import Retry
|
|
||||||
|
|
||||||
from lancedb.common import Credential
|
|
||||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
|
||||||
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
|
|
||||||
from lancedb.remote.errors import LanceDBClientError
|
|
||||||
|
|
||||||
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_not_closed(f):
|
|
||||||
@functools.wraps(f)
|
|
||||||
def wrapped(self, *args, **kwargs):
|
|
||||||
if self.closed:
|
|
||||||
raise ValueError("Connection is closed")
|
|
||||||
return f(self, *args, **kwargs)
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
def _read_ipc(resp: requests.Response) -> pa.Table:
|
|
||||||
resp_body = resp.content
|
|
||||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
|
||||||
return reader.read_all()
|
|
||||||
|
|
||||||
|
|
||||||
@attrs.define(slots=False)
|
|
||||||
class RestfulLanceDBClient:
|
|
||||||
db_name: str
|
|
||||||
region: str
|
|
||||||
api_key: Credential
|
|
||||||
host_override: Optional[str] = attrs.field(default=None)
|
|
||||||
|
|
||||||
closed: bool = attrs.field(default=False, init=False)
|
|
||||||
|
|
||||||
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
|
|
||||||
read_timeout: float = attrs.field(default=300.0, kw_only=True)
|
|
||||||
|
|
||||||
@functools.cached_property
|
|
||||||
def session(self) -> requests.Session:
|
|
||||||
sess = requests.Session()
|
|
||||||
|
|
||||||
retry_adapter_instance = retry_adapter(retry_adapter_options())
|
|
||||||
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
|
|
||||||
|
|
||||||
adapter_class = LanceDBClientHTTPAdapterFactory()
|
|
||||||
sess.mount("https://", adapter_class())
|
|
||||||
return sess
|
|
||||||
|
|
||||||
@property
|
|
||||||
def url(self) -> str:
|
|
||||||
return (
|
|
||||||
self.host_override
|
|
||||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
self.close()
|
|
||||||
return False # Do not suppress exceptions
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.session.close()
|
|
||||||
self.closed = True
|
|
||||||
|
|
||||||
@functools.cached_property
|
|
||||||
def headers(self) -> Dict[str, str]:
|
|
||||||
headers = {
|
|
||||||
"x-api-key": self.api_key,
|
|
||||||
}
|
|
||||||
if self.region == "local": # Local test mode
|
|
||||||
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
|
||||||
if self.host_override:
|
|
||||||
headers["x-lancedb-database"] = self.db_name
|
|
||||||
return headers
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _check_status(resp: requests.Response):
|
|
||||||
# Leaving request id empty for now, as we'll be replacing this impl
|
|
||||||
# with the Rust one shortly.
|
|
||||||
if resp.status_code == 404:
|
|
||||||
raise LanceDBClientError(
|
|
||||||
f"Not found: {resp.text}", request_id="", status_code=404
|
|
||||||
)
|
|
||||||
elif 400 <= resp.status_code < 500:
|
|
||||||
raise LanceDBClientError(
|
|
||||||
f"Bad Request: {resp.status_code}, error: {resp.text}",
|
|
||||||
request_id="",
|
|
||||||
status_code=resp.status_code,
|
|
||||||
)
|
|
||||||
elif 500 <= resp.status_code < 600:
|
|
||||||
raise LanceDBClientError(
|
|
||||||
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
|
|
||||||
request_id="",
|
|
||||||
status_code=resp.status_code,
|
|
||||||
)
|
|
||||||
elif resp.status_code != 200:
|
|
||||||
raise LanceDBClientError(
|
|
||||||
f"Unknown Error: {resp.status_code}, error: {resp.text}",
|
|
||||||
request_id="",
|
|
||||||
status_code=resp.status_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
@_check_not_closed
|
|
||||||
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
|
||||||
"""Send a GET request and returns the deserialized response payload."""
|
|
||||||
if isinstance(params, BaseModel):
|
|
||||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
|
||||||
with self.session.get(
|
|
||||||
urljoin(self.url, uri),
|
|
||||||
params=params,
|
|
||||||
headers=self.headers,
|
|
||||||
timeout=(self.connection_timeout, self.read_timeout),
|
|
||||||
) as resp:
|
|
||||||
self._check_status(resp)
|
|
||||||
return resp.json()
|
|
||||||
|
|
||||||
@_check_not_closed
|
|
||||||
def post(
|
|
||||||
self,
|
|
||||||
uri: str,
|
|
||||||
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
content_type: Optional[str] = None,
|
|
||||||
deserialize: Callable = lambda resp: resp.json(),
|
|
||||||
request_id: Optional[str] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Send a POST request and returns the deserialized response payload.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
uri : str
|
|
||||||
The uri to send the POST request to.
|
|
||||||
data: Union[Dict[str, Any], BaseModel]
|
|
||||||
request_id: Optional[str]
|
|
||||||
Optional client side request id to be sent in the request headers.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if isinstance(data, BaseModel):
|
|
||||||
data: Dict[str, Any] = data.dict(exclude_none=True)
|
|
||||||
if isinstance(data, bytes):
|
|
||||||
req_kwargs = {"data": data}
|
|
||||||
else:
|
|
||||||
req_kwargs = {"json": data}
|
|
||||||
|
|
||||||
headers = self.headers.copy()
|
|
||||||
if content_type is not None:
|
|
||||||
headers["content-type"] = content_type
|
|
||||||
if request_id is not None:
|
|
||||||
headers["x-request-id"] = request_id
|
|
||||||
with self.session.post(
|
|
||||||
urljoin(self.url, uri),
|
|
||||||
headers=headers,
|
|
||||||
params=params,
|
|
||||||
timeout=(self.connection_timeout, self.read_timeout),
|
|
||||||
**req_kwargs,
|
|
||||||
) as resp:
|
|
||||||
self._check_status(resp)
|
|
||||||
return deserialize(resp)
|
|
||||||
|
|
||||||
@_check_not_closed
|
|
||||||
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
|
|
||||||
"""List all tables in the database."""
|
|
||||||
if page_token is None:
|
|
||||||
page_token = ""
|
|
||||||
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
|
||||||
return json["tables"]
|
|
||||||
|
|
||||||
@_check_not_closed
|
|
||||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
|
||||||
"""Query a table."""
|
|
||||||
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
|
||||||
return VectorQueryResult(tbl)
|
|
||||||
|
|
||||||
def mount_retry_adapter_for_table(self, table_name: str) -> None:
|
|
||||||
"""
|
|
||||||
Adds an http adapter to session that will retry retryable requests to the table.
|
|
||||||
"""
|
|
||||||
retry_options = retry_adapter_options(methods=["GET", "POST"])
|
|
||||||
retry_adapter_instance = retry_adapter(retry_options)
|
|
||||||
session = self.session
|
|
||||||
|
|
||||||
session.mount(
|
|
||||||
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
|
|
||||||
)
|
|
||||||
session.mount(
|
|
||||||
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
|
|
||||||
retry_adapter_instance,
|
|
||||||
)
|
|
||||||
session.mount(
|
|
||||||
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
|
|
||||||
retry_adapter_instance,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
|
|
||||||
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
|
|
||||||
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
|
|
||||||
"backoff_factor": float(
|
|
||||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
|
|
||||||
),
|
|
||||||
"backoff_jitter": float(
|
|
||||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
|
|
||||||
),
|
|
||||||
"statuses": [
|
|
||||||
int(i.strip())
|
|
||||||
for i in os.environ.get(
|
|
||||||
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
|
|
||||||
).split(",")
|
|
||||||
],
|
|
||||||
"methods": methods,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
|
|
||||||
total_retries = options["retries"]
|
|
||||||
connect_retries = options["connect_retries"]
|
|
||||||
read_retries = options["read_retries"]
|
|
||||||
backoff_factor = options["backoff_factor"]
|
|
||||||
backoff_jitter = options["backoff_jitter"]
|
|
||||||
statuses = options["statuses"]
|
|
||||||
methods = frozenset(options["methods"])
|
|
||||||
logging.debug(
|
|
||||||
f"Setting up retry adapter with {total_retries} retries," # noqa G003
|
|
||||||
+ f"connect retries {connect_retries}, read retries {read_retries},"
|
|
||||||
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
|
|
||||||
+ f"methods {methods}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return HTTPAdapter(
|
|
||||||
max_retries=Retry(
|
|
||||||
total=total_retries,
|
|
||||||
connect=connect_retries,
|
|
||||||
read=read_retries,
|
|
||||||
backoff_factor=backoff_factor,
|
|
||||||
backoff_jitter=backoff_jitter,
|
|
||||||
status_forcelist=statuses,
|
|
||||||
allowed_methods=methods,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
# Copyright 2024 LanceDB Developers
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# This module contains an adapter that will close connections if they have not been
|
|
||||||
# used before a certain timeout. This is necessary because some load balancers will
|
|
||||||
# close connections after a certain amount of time, but the request module may not yet
|
|
||||||
# have received the FIN/ACK and will try to reuse the connection.
|
|
||||||
#
|
|
||||||
# TODO some of the code here can be simplified if/when this PR is merged:
|
|
||||||
# https://github.com/urllib3/urllib3/pull/3275
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
|
|
||||||
from requests.adapters import HTTPAdapter
|
|
||||||
from urllib3.connection import HTTPSConnection
|
|
||||||
from urllib3.connectionpool import HTTPSConnectionPool
|
|
||||||
from urllib3.poolmanager import PoolManager
|
|
||||||
|
|
||||||
|
|
||||||
def get_client_connection_timeout() -> int:
|
|
||||||
return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300"))
|
|
||||||
|
|
||||||
|
|
||||||
class LanceDBHTTPSConnection(HTTPSConnection):
|
|
||||||
"""
|
|
||||||
HTTPSConnection that tracks the last time it was used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
idle_timeout: datetime.timedelta
|
|
||||||
last_activity: datetime.datetime
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.last_activity = datetime.datetime.now()
|
|
||||||
|
|
||||||
def request(self, *args, **kwargs):
|
|
||||||
self.last_activity = datetime.datetime.now()
|
|
||||||
super().request(*args, **kwargs)
|
|
||||||
|
|
||||||
def is_expired(self):
|
|
||||||
return datetime.datetime.now() - self.last_activity > self.idle_timeout
|
|
||||||
|
|
||||||
|
|
||||||
def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int):
|
|
||||||
"""
|
|
||||||
Creates a connection pool class that can be used to close idle connections.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class LanceDBHTTPSConnectionPool(HTTPSConnectionPool):
|
|
||||||
# override the connection class
|
|
||||||
ConnectionCls = LanceDBHTTPSConnection
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def _get_conn(self, timeout: float | None = None):
|
|
||||||
logging.debug("Getting https connection")
|
|
||||||
conn = super()._get_conn(timeout)
|
|
||||||
if conn.is_expired():
|
|
||||||
logging.debug("Closing expired connection")
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def _new_conn(self):
|
|
||||||
conn = super()._new_conn()
|
|
||||||
conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout)
|
|
||||||
return conn
|
|
||||||
|
|
||||||
return LanceDBHTTPSConnectionPool
|
|
||||||
|
|
||||||
|
|
||||||
class LanceDBClientPoolManager(PoolManager):
|
|
||||||
def __init__(
|
|
||||||
self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs
|
|
||||||
):
|
|
||||||
super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs)
|
|
||||||
# inject our connection pool impl
|
|
||||||
connection_pool_class = LanceDBHTTPSConnectionPoolFactory(
|
|
||||||
client_idle_timeout=client_idle_timeout
|
|
||||||
)
|
|
||||||
self.pool_classes_by_scheme["https"] = connection_pool_class
|
|
||||||
|
|
||||||
|
|
||||||
def LanceDBClientHTTPAdapterFactory():
|
|
||||||
"""
|
|
||||||
Creates an HTTPAdapter class that can be used to close idle connections
|
|
||||||
"""
|
|
||||||
|
|
||||||
# closure over the timeout
|
|
||||||
client_idle_timeout = get_client_connection_timeout()
|
|
||||||
|
|
||||||
class LanceDBClientRequestHTTPAdapter(HTTPAdapter):
|
|
||||||
def init_poolmanager(self, connections, maxsize, block=False):
|
|
||||||
# inject our pool manager impl
|
|
||||||
self.poolmanager = LanceDBClientPoolManager(
|
|
||||||
client_idle_timeout=client_idle_timeout,
|
|
||||||
num_pools=connections,
|
|
||||||
maxsize=maxsize,
|
|
||||||
block=block,
|
|
||||||
)
|
|
||||||
|
|
||||||
return LanceDBClientRequestHTTPAdapter
|
|
||||||
@@ -11,13 +11,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Iterable, List, Optional, Union
|
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
import warnings
|
||||||
|
|
||||||
from cachetools import TTLCache
|
from lancedb import connect_async
|
||||||
|
from lancedb.remote import ClientConfig
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from overrides import override
|
from overrides import override
|
||||||
|
|
||||||
@@ -25,10 +28,8 @@ from ..common import DATA
|
|||||||
from ..db import DBConnection
|
from ..db import DBConnection
|
||||||
from ..embeddings import EmbeddingFunctionConfig
|
from ..embeddings import EmbeddingFunctionConfig
|
||||||
from ..pydantic import LanceModel
|
from ..pydantic import LanceModel
|
||||||
from ..table import Table, sanitize_create_table
|
from ..table import Table
|
||||||
from ..util import validate_table_name
|
from ..util import validate_table_name
|
||||||
from .arrow import to_ipc_binary
|
|
||||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteDBConnection(DBConnection):
|
class RemoteDBConnection(DBConnection):
|
||||||
@@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection):
|
|||||||
region: str,
|
region: str,
|
||||||
host_override: Optional[str] = None,
|
host_override: Optional[str] = None,
|
||||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||||
connection_timeout: float = 120.0,
|
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
|
||||||
read_timeout: float = 300.0,
|
connection_timeout: Optional[float] = None,
|
||||||
|
read_timeout: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""Connect to a remote LanceDB database."""
|
"""Connect to a remote LanceDB database."""
|
||||||
|
|
||||||
|
if isinstance(client_config, dict):
|
||||||
|
client_config = ClientConfig(**client_config)
|
||||||
|
elif client_config is None:
|
||||||
|
client_config = ClientConfig()
|
||||||
|
|
||||||
|
# These are legacy options from the old Python-based client. We keep them
|
||||||
|
# here for backwards compatibility, but will remove them in a future release.
|
||||||
|
if request_thread_pool is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"request_thread_pool is no longer used and will be removed in "
|
||||||
|
"a future release.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
if connection_timeout is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"connection_timeout is deprecated and will be removed in a future "
|
||||||
|
"release. Please use client_config.timeout_config.connect_timeout "
|
||||||
|
"instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
client_config.timeout_config.connect_timeout = timedelta(
|
||||||
|
seconds=connection_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
if read_timeout is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"read_timeout is deprecated and will be removed in a future release. "
|
||||||
|
"Please use client_config.timeout_config.read_timeout instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
|
||||||
|
|
||||||
parsed = urlparse(db_url)
|
parsed = urlparse(db_url)
|
||||||
if parsed.scheme != "db":
|
if parsed.scheme != "db":
|
||||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||||
self._uri = str(db_url)
|
|
||||||
self.db_name = parsed.netloc
|
self.db_name = parsed.netloc
|
||||||
self.api_key = api_key
|
|
||||||
self._client = RestfulLanceDBClient(
|
import nest_asyncio
|
||||||
self.db_name,
|
|
||||||
region,
|
nest_asyncio.apply()
|
||||||
api_key,
|
try:
|
||||||
host_override,
|
self._loop = asyncio.get_running_loop()
|
||||||
connection_timeout=connection_timeout,
|
except RuntimeError:
|
||||||
read_timeout=read_timeout,
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
|
||||||
|
self.client_config = client_config
|
||||||
|
|
||||||
|
self._conn = self._loop.run_until_complete(
|
||||||
|
connect_async(
|
||||||
|
db_url,
|
||||||
|
api_key=api_key,
|
||||||
|
region=region,
|
||||||
|
host_override=host_override,
|
||||||
|
client_config=client_config,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self._request_thread_pool = request_thread_pool
|
|
||||||
self._table_cache = TTLCache(maxsize=10000, ttl=300)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"RemoteConnect(name={self.db_name})"
|
return f"RemoteConnect(name={self.db_name})"
|
||||||
@@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection):
|
|||||||
-------
|
-------
|
||||||
An iterator of table names.
|
An iterator of table names.
|
||||||
"""
|
"""
|
||||||
while True:
|
return self._loop.run_until_complete(
|
||||||
result = self._client.list_tables(limit, page_token)
|
self._conn.table_names(start_after=page_token, limit=limit)
|
||||||
|
)
|
||||||
if len(result) > 0:
|
|
||||||
page_token = result[len(result) - 1]
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
for item in result:
|
|
||||||
self._table_cache[item] = True
|
|
||||||
yield item
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
|
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
|
||||||
@@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection):
|
|||||||
"""
|
"""
|
||||||
from .table import RemoteTable
|
from .table import RemoteTable
|
||||||
|
|
||||||
self._client.mount_retry_adapter_for_table(name)
|
|
||||||
|
|
||||||
if index_cache_size is not None:
|
if index_cache_size is not None:
|
||||||
logging.info(
|
logging.info(
|
||||||
"index_cache_size is ignored in LanceDb Cloud"
|
"index_cache_size is ignored in LanceDb Cloud"
|
||||||
" (there is no local cache to configure)"
|
" (there is no local cache to configure)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if table exists
|
table = self._loop.run_until_complete(self._conn.open_table(name))
|
||||||
if self._table_cache.get(name) is None:
|
return RemoteTable(table, self.db_name, self._loop)
|
||||||
self._client.post(f"/v1/table/{name}/describe/")
|
|
||||||
self._table_cache[name] = True
|
|
||||||
|
|
||||||
return RemoteTable(self, name)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_table(
|
def create_table(
|
||||||
@@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection):
|
|||||||
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||||
"for this feature."
|
"for this feature."
|
||||||
)
|
)
|
||||||
if mode is not None:
|
|
||||||
logging.warning("mode is not yet supported on LanceDB Cloud.")
|
|
||||||
|
|
||||||
data, schema = sanitize_create_table(
|
|
||||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
|
||||||
)
|
|
||||||
|
|
||||||
from .table import RemoteTable
|
from .table import RemoteTable
|
||||||
|
|
||||||
data = to_ipc_binary(data)
|
table = self._loop.run_until_complete(
|
||||||
request_id = uuid.uuid4().hex
|
self._conn.create_table(
|
||||||
|
name,
|
||||||
self._client.post(
|
data,
|
||||||
f"/v1/table/{name}/create/",
|
mode=mode,
|
||||||
data=data,
|
schema=schema,
|
||||||
request_id=request_id,
|
on_bad_vectors=on_bad_vectors,
|
||||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
fill_value=fill_value,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
self._table_cache[name] = True
|
return RemoteTable(table, self.db_name, self._loop)
|
||||||
return RemoteTable(self, name)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def drop_table(self, name: str):
|
def drop_table(self, name: str):
|
||||||
@@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
name: str
|
name: str
|
||||||
The name of the table.
|
The name of the table.
|
||||||
"""
|
"""
|
||||||
|
self._loop.run_until_complete(self._conn.drop_table(name))
|
||||||
self._client.post(
|
|
||||||
f"/v1/table/{name}/drop/",
|
|
||||||
)
|
|
||||||
self._table_cache.pop(name, default=None)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def rename_table(self, cur_name: str, new_name: str):
|
def rename_table(self, cur_name: str, new_name: str):
|
||||||
@@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
new_name: str
|
new_name: str
|
||||||
The new name of the table.
|
The new name of the table.
|
||||||
"""
|
"""
|
||||||
self._client.post(
|
self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name))
|
||||||
f"/v1/table/{cur_name}/rename/",
|
|
||||||
data={"new_table_name": new_name},
|
|
||||||
)
|
|
||||||
self._table_cache.pop(cur_name, default=None)
|
|
||||||
self._table_cache[new_name] = True
|
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close the connection to the database."""
|
"""Close the connection to the database."""
|
||||||
|
|||||||
@@ -11,53 +11,57 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from concurrent.futures import Future
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||||
|
|
||||||
|
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from lance import json_to_schema
|
|
||||||
|
|
||||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from lancedb.merge import LanceMergeInsertBuilder
|
from lancedb.merge import LanceMergeInsertBuilder
|
||||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
|
|
||||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
|
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
|
||||||
from ..table import Query, Table, _sanitize_data
|
from ..table import AsyncTable, Query, Table
|
||||||
from ..util import value_to_sql, infer_vector_column_name
|
|
||||||
from .arrow import to_ipc_binary
|
|
||||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
|
||||||
from .db import RemoteDBConnection
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteTable(Table):
|
class RemoteTable(Table):
|
||||||
def __init__(self, conn: RemoteDBConnection, name: str):
|
def __init__(
|
||||||
self._conn = conn
|
self,
|
||||||
self.name = name
|
table: AsyncTable,
|
||||||
|
db_name: str,
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
|
):
|
||||||
|
self._loop = loop
|
||||||
|
self._table = table
|
||||||
|
self.db_name = db_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""The name of the table"""
|
||||||
|
return self._table.name
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"RemoteTable({self._conn.db_name}.{self.name})"
|
return f"RemoteTable({self.db_name}.{self.name})"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
self.count_rows(None)
|
self.count_rows(None)
|
||||||
|
|
||||||
@cached_property
|
@property
|
||||||
def schema(self) -> pa.Schema:
|
def schema(self) -> pa.Schema:
|
||||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||||
of this Table
|
of this Table
|
||||||
|
|
||||||
"""
|
"""
|
||||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
return self._loop.run_until_complete(self._table.schema())
|
||||||
schema = json_to_schema(resp["schema"])
|
|
||||||
return schema
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def version(self) -> int:
|
def version(self) -> int:
|
||||||
"""Get the current version of the table"""
|
"""Get the current version of the table"""
|
||||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
return self._loop.run_until_complete(self._table.version())
|
||||||
return resp["version"]
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def embedding_functions(self) -> dict:
|
def embedding_functions(self) -> dict:
|
||||||
@@ -84,20 +88,18 @@ class RemoteTable(Table):
|
|||||||
|
|
||||||
def list_indices(self):
|
def list_indices(self):
|
||||||
"""List all the indices on the table"""
|
"""List all the indices on the table"""
|
||||||
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
|
return self._loop.run_until_complete(self._table.list_indices())
|
||||||
return resp
|
|
||||||
|
|
||||||
def index_stats(self, index_uuid: str):
|
def index_stats(self, index_uuid: str):
|
||||||
"""List all the stats of a specified index"""
|
"""List all the stats of a specified index"""
|
||||||
resp = self._conn._client.post(
|
return self._loop.run_until_complete(self._table.index_stats(index_uuid))
|
||||||
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
|
|
||||||
)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def create_scalar_index(
|
def create_scalar_index(
|
||||||
self,
|
self,
|
||||||
column: str,
|
column: str,
|
||||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
|
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
|
||||||
|
*,
|
||||||
|
replace: bool = False,
|
||||||
):
|
):
|
||||||
"""Creates a scalar index
|
"""Creates a scalar index
|
||||||
Parameters
|
Parameters
|
||||||
@@ -107,20 +109,23 @@ class RemoteTable(Table):
|
|||||||
or string column.
|
or string column.
|
||||||
index_type : str
|
index_type : str
|
||||||
The index type of the scalar index. Must be "scalar" (BTREE),
|
The index type of the scalar index. Must be "scalar" (BTREE),
|
||||||
"BTREE", "BITMAP", or "LABEL_LIST"
|
"BTREE", "BITMAP", or "LABEL_LIST",
|
||||||
|
replace : bool
|
||||||
|
If True, replace the existing index with the new one.
|
||||||
"""
|
"""
|
||||||
|
if index_type == "scalar" or index_type == "BTREE":
|
||||||
|
config = BTree()
|
||||||
|
elif index_type == "BITMAP":
|
||||||
|
config = Bitmap()
|
||||||
|
elif index_type == "LABEL_LIST":
|
||||||
|
config = LabelList()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown index type: {index_type}")
|
||||||
|
|
||||||
data = {
|
self._loop.run_until_complete(
|
||||||
"column": column,
|
self._table.create_index(column, config=config, replace=replace)
|
||||||
"index_type": index_type,
|
|
||||||
"replace": True,
|
|
||||||
}
|
|
||||||
resp = self._conn._client.post(
|
|
||||||
f"/v1/table/{self.name}/create_scalar_index/", data=data
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def create_fts_index(
|
def create_fts_index(
|
||||||
self,
|
self,
|
||||||
column: str,
|
column: str,
|
||||||
@@ -128,15 +133,10 @@ class RemoteTable(Table):
|
|||||||
replace: bool = False,
|
replace: bool = False,
|
||||||
with_position: bool = True,
|
with_position: bool = True,
|
||||||
):
|
):
|
||||||
data = {
|
config = FTS(with_position=with_position)
|
||||||
"column": column,
|
self._loop.run_until_complete(
|
||||||
"index_type": "FTS",
|
self._table.create_index(column, config=config, replace=replace)
|
||||||
"replace": replace,
|
|
||||||
}
|
|
||||||
resp = self._conn._client.post(
|
|
||||||
f"/v1/table/{self.name}/create_index/", data=data
|
|
||||||
)
|
)
|
||||||
return resp
|
|
||||||
|
|
||||||
def create_index(
|
def create_index(
|
||||||
self,
|
self,
|
||||||
@@ -204,17 +204,22 @@ class RemoteTable(Table):
|
|||||||
"Existing indexes will always be replaced."
|
"Existing indexes will always be replaced."
|
||||||
)
|
)
|
||||||
|
|
||||||
data = {
|
index_type = index_type.upper()
|
||||||
"column": vector_column_name,
|
if index_type == "VECTOR" or index_type == "IVF_PQ":
|
||||||
"index_type": index_type,
|
config = IvfPq(distance_type=metric)
|
||||||
"metric_type": metric,
|
elif index_type == "IVF_HNSW_PQ":
|
||||||
"index_cache_size": index_cache_size,
|
config = HnswPq(distance_type=metric)
|
||||||
}
|
elif index_type == "IVF_HNSW_SQ":
|
||||||
resp = self._conn._client.post(
|
config = HnswSq(distance_type=metric)
|
||||||
f"/v1/table/{self.name}/create_index/", data=data
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown vector index type: {index_type}. Valid options are"
|
||||||
|
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||||
)
|
)
|
||||||
|
|
||||||
return resp
|
self._loop.run_until_complete(
|
||||||
|
self._table.create_index(vector_column_name, config=config)
|
||||||
|
)
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
@@ -246,22 +251,10 @@ class RemoteTable(Table):
|
|||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
|
|
||||||
"""
|
"""
|
||||||
data, _ = _sanitize_data(
|
self._loop.run_until_complete(
|
||||||
data,
|
self._table.add(
|
||||||
self.schema,
|
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||||
metadata=self.schema.metadata,
|
|
||||||
on_bad_vectors=on_bad_vectors,
|
|
||||||
fill_value=fill_value,
|
|
||||||
)
|
)
|
||||||
payload = to_ipc_binary(data)
|
|
||||||
|
|
||||||
request_id = uuid.uuid4().hex
|
|
||||||
|
|
||||||
self._conn._client.post(
|
|
||||||
f"/v1/table/{self.name}/insert/",
|
|
||||||
data=payload,
|
|
||||||
params={"request_id": request_id, "mode": mode},
|
|
||||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
@@ -334,16 +327,6 @@ class RemoteTable(Table):
|
|||||||
- and also the "_distance" column which is the distance between the query
|
- and also the "_distance" column which is the distance between the query
|
||||||
vector and the returned vector.
|
vector and the returned vector.
|
||||||
"""
|
"""
|
||||||
# empty query builder is not supported in saas, raise error
|
|
||||||
if query is None and query_type != "hybrid":
|
|
||||||
raise ValueError("Empty query is not supported")
|
|
||||||
vector_column_name = infer_vector_column_name(
|
|
||||||
schema=self.schema,
|
|
||||||
query_type=query_type,
|
|
||||||
query=query,
|
|
||||||
vector_column_name=vector_column_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
return LanceQueryBuilder.create(
|
return LanceQueryBuilder.create(
|
||||||
self,
|
self,
|
||||||
query,
|
query,
|
||||||
@@ -356,38 +339,10 @@ class RemoteTable(Table):
|
|||||||
def _execute_query(
|
def _execute_query(
|
||||||
self, query: Query, batch_size: Optional[int] = None
|
self, query: Query, batch_size: Optional[int] = None
|
||||||
) -> pa.RecordBatchReader:
|
) -> pa.RecordBatchReader:
|
||||||
if (
|
return self._loop.run_until_complete(
|
||||||
query.vector is not None
|
self._table._execute_query(query, batch_size=batch_size)
|
||||||
and len(query.vector) > 0
|
|
||||||
and not isinstance(query.vector[0], float)
|
|
||||||
):
|
|
||||||
if self._conn._request_thread_pool is None:
|
|
||||||
|
|
||||||
def submit(name, q):
|
|
||||||
f = Future()
|
|
||||||
f.set_result(self._conn._client.query(name, q))
|
|
||||||
return f
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def submit(name, q):
|
|
||||||
return self._conn._request_thread_pool.submit(
|
|
||||||
self._conn._client.query, name, q
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = []
|
|
||||||
for v in query.vector:
|
|
||||||
v = list(v)
|
|
||||||
q = query.copy()
|
|
||||||
q.vector = v
|
|
||||||
results.append(submit(self.name, q))
|
|
||||||
return pa.concat_tables(
|
|
||||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
|
||||||
).to_reader()
|
|
||||||
else:
|
|
||||||
result = self._conn._client.query(self.name, query)
|
|
||||||
return result.to_arrow().to_reader()
|
|
||||||
|
|
||||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||||
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||||
that can be used to create a "merge insert" operation.
|
that can be used to create a "merge insert" operation.
|
||||||
@@ -403,42 +358,8 @@ class RemoteTable(Table):
|
|||||||
on_bad_vectors: str,
|
on_bad_vectors: str,
|
||||||
fill_value: float,
|
fill_value: float,
|
||||||
):
|
):
|
||||||
data, _ = _sanitize_data(
|
self._loop.run_until_complete(
|
||||||
new_data,
|
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
|
||||||
self.schema,
|
|
||||||
metadata=None,
|
|
||||||
on_bad_vectors=on_bad_vectors,
|
|
||||||
fill_value=fill_value,
|
|
||||||
)
|
|
||||||
payload = to_ipc_binary(data)
|
|
||||||
|
|
||||||
params = {}
|
|
||||||
if len(merge._on) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
"RemoteTable only supports a single on key in merge_insert"
|
|
||||||
)
|
|
||||||
params["on"] = merge._on[0]
|
|
||||||
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
|
|
||||||
if merge._when_matched_update_all_condition is not None:
|
|
||||||
params["when_matched_update_all_filt"] = (
|
|
||||||
merge._when_matched_update_all_condition
|
|
||||||
)
|
|
||||||
params["when_not_matched_insert_all"] = str(
|
|
||||||
merge._when_not_matched_insert_all
|
|
||||||
).lower()
|
|
||||||
params["when_not_matched_by_source_delete"] = str(
|
|
||||||
merge._when_not_matched_by_source_delete
|
|
||||||
).lower()
|
|
||||||
if merge._when_not_matched_by_source_condition is not None:
|
|
||||||
params["when_not_matched_by_source_delete_filt"] = (
|
|
||||||
merge._when_not_matched_by_source_condition
|
|
||||||
)
|
|
||||||
|
|
||||||
self._conn._client.post(
|
|
||||||
f"/v1/table/{self.name}/merge_insert/",
|
|
||||||
data=payload,
|
|
||||||
params=params,
|
|
||||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, predicate: str):
|
def delete(self, predicate: str):
|
||||||
@@ -488,8 +409,7 @@ class RemoteTable(Table):
|
|||||||
x vector _distance # doctest: +SKIP
|
x vector _distance # doctest: +SKIP
|
||||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||||
"""
|
"""
|
||||||
payload = {"predicate": predicate}
|
self._loop.run_until_complete(self._table.delete(predicate))
|
||||||
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
|
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@@ -539,18 +459,9 @@ class RemoteTable(Table):
|
|||||||
2 2 [10.0, 10.0] # doctest: +SKIP
|
2 2 [10.0, 10.0] # doctest: +SKIP
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if values is not None and values_sql is not None:
|
self._loop.run_until_complete(
|
||||||
raise ValueError("Only one of values or values_sql can be provided")
|
self._table.update(where=where, updates=values, updates_sql=values_sql)
|
||||||
if values is None and values_sql is None:
|
)
|
||||||
raise ValueError("Either values or values_sql must be provided")
|
|
||||||
|
|
||||||
if values is not None:
|
|
||||||
updates = [[k, value_to_sql(v)] for k, v in values.items()]
|
|
||||||
else:
|
|
||||||
updates = [[k, v] for k, v in values_sql.items()]
|
|
||||||
|
|
||||||
payload = {"predicate": where, "updates": updates}
|
|
||||||
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
|
|
||||||
|
|
||||||
def cleanup_old_versions(self, *_):
|
def cleanup_old_versions(self, *_):
|
||||||
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||||
@@ -564,12 +475,21 @@ class RemoteTable(Table):
|
|||||||
"compact_files() is not supported on the LanceDB cloud"
|
"compact_files() is not supported on the LanceDB cloud"
|
||||||
)
|
)
|
||||||
|
|
||||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
def optimize(
|
||||||
payload = {"predicate": filter}
|
self,
|
||||||
resp = self._conn._client.post(
|
*,
|
||||||
f"/v1/table/{self.name}/count_rows/", data=payload
|
cleanup_older_than: Optional[timedelta] = None,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
):
|
||||||
|
"""optimize() is not supported on the LanceDB cloud.
|
||||||
|
Indices are optimized automatically."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"optimize() is not supported on the LanceDB cloud. "
|
||||||
|
"Indices are optimized automatically."
|
||||||
)
|
)
|
||||||
return resp
|
|
||||||
|
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||||
|
return self._loop.run_until_complete(self._table.count_rows(filter))
|
||||||
|
|
||||||
def add_columns(self, transforms: Dict[str, str]):
|
def add_columns(self, transforms: Dict[str, str]):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .openai import OpenaiReranker
|
|||||||
from .jinaai import JinaReranker
|
from .jinaai import JinaReranker
|
||||||
from .rrf import RRFReranker
|
from .rrf import RRFReranker
|
||||||
from .answerdotai import AnswerdotaiRerankers
|
from .answerdotai import AnswerdotaiRerankers
|
||||||
|
from .voyageai import VoyageAIReranker
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Reranker",
|
"Reranker",
|
||||||
@@ -18,4 +19,5 @@ __all__ = [
|
|||||||
"JinaReranker",
|
"JinaReranker",
|
||||||
"RRFReranker",
|
"RRFReranker",
|
||||||
"AnswerdotaiRerankers",
|
"AnswerdotaiRerankers",
|
||||||
|
"VoyageAIReranker",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import requests
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -57,6 +56,8 @@ class JinaReranker(Reranker):
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def _client(self):
|
def _client(self):
|
||||||
|
import requests
|
||||||
|
|
||||||
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
|
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"JINA_API_KEY not set. Either set it in your environment or \
|
"JINA_API_KEY not set. Either set it in your environment or \
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from numpy import NaN
|
from numpy import nan
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
from .base import Reranker
|
from .base import Reranker
|
||||||
@@ -71,7 +71,7 @@ class LinearCombinationReranker(Reranker):
|
|||||||
elif self.score == "all":
|
elif self.score == "all":
|
||||||
results = results.append_column(
|
results = results.append_column(
|
||||||
"_distance",
|
"_distance",
|
||||||
pa.array([NaN] * len(fts_results), type=pa.float32()),
|
pa.array([nan] * len(fts_results), type=pa.float32()),
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ class LinearCombinationReranker(Reranker):
|
|||||||
elif self.score == "all":
|
elif self.score == "all":
|
||||||
results = results.append_column(
|
results = results.append_column(
|
||||||
"_score",
|
"_score",
|
||||||
pa.array([NaN] * len(vector_results), type=pa.float32()),
|
pa.array([nan] * len(vector_results), type=pa.float32()),
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
133
python/python/lancedb/rerankers/voyageai.py
Normal file
133
python/python/lancedb/rerankers/voyageai.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageAIReranker(Reranker):
|
||||||
|
"""
|
||||||
|
Reranks the results using the VoyageAI Rerank API.
|
||||||
|
https://docs.voyageai.com/docs/reranker
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_name : str, default "rerank-english-v2.0"
|
||||||
|
The name of the cross encoder model to use. Available voyageai models are:
|
||||||
|
- rerank-2
|
||||||
|
- rerank-2-lite
|
||||||
|
column : str, default "text"
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
top_n : int, default None
|
||||||
|
The number of results to return. If None, will return all results.
|
||||||
|
return_score : str, default "relevance"
|
||||||
|
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||||
|
api_key : str, default None
|
||||||
|
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||||
|
truncation : Optional[bool], default None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
column: str = "text",
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
return_score="relevance",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
truncation: Optional[bool] = True,
|
||||||
|
):
|
||||||
|
super().__init__(return_score)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.column = column
|
||||||
|
self.top_n = top_n
|
||||||
|
self.api_key = api_key
|
||||||
|
self.truncation = truncation
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _client(self):
|
||||||
|
voyageai = attempt_import_or_raise("voyageai")
|
||||||
|
if os.environ.get("VOYAGE_API_KEY") is None and self.api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"VOYAGE_API_KEY not set. Either set it in your environment or \
|
||||||
|
pass it as `api_key` argument to the VoyageAIReranker."
|
||||||
|
)
|
||||||
|
return voyageai.Client(
|
||||||
|
api_key=os.environ.get("VOYAGE_API_KEY") or self.api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _rerank(self, result_set: pa.Table, query: str):
|
||||||
|
docs = result_set[self.column].to_pylist()
|
||||||
|
response = self._client.rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
top_k=self.top_n,
|
||||||
|
model=self.model_name,
|
||||||
|
truncation=self.truncation,
|
||||||
|
)
|
||||||
|
results = (
|
||||||
|
response.results
|
||||||
|
) # returns list (text, idx, relevance) attributes sorted descending by score
|
||||||
|
indices, scores = list(
|
||||||
|
zip(*[(result.index, result.relevance_score) for result in results])
|
||||||
|
) # tuples
|
||||||
|
result_set = result_set.take(list(indices))
|
||||||
|
# add the scores
|
||||||
|
result_set = result_set.append_column(
|
||||||
|
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||||
|
)
|
||||||
|
|
||||||
|
return result_set
|
||||||
|
|
||||||
|
def rerank_hybrid(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
combined_results = self.merge_results(vector_results, fts_results)
|
||||||
|
combined_results = self._rerank(combined_results, query)
|
||||||
|
if self.score == "relevance":
|
||||||
|
combined_results = self._keep_relevance_score(combined_results)
|
||||||
|
elif self.score == "all":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"return_score='all' not implemented for voyageai reranker"
|
||||||
|
)
|
||||||
|
return combined_results
|
||||||
|
|
||||||
|
def rerank_vector(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
):
|
||||||
|
result_set = self._rerank(vector_results, query)
|
||||||
|
if self.score == "relevance":
|
||||||
|
result_set = result_set.drop_columns(["_distance"])
|
||||||
|
|
||||||
|
return result_set
|
||||||
|
|
||||||
|
def rerank_fts(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
result_set = self._rerank(fts_results, query)
|
||||||
|
if self.score == "relevance":
|
||||||
|
result_set = result_set.drop_columns(["_score"])
|
||||||
|
|
||||||
|
return result_set
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -32,7 +33,7 @@ import pyarrow.fs as pa_fs
|
|||||||
from lance import LanceDataset
|
from lance import LanceDataset
|
||||||
from lance.dependencies import _check_for_hugging_face
|
from lance.dependencies import _check_for_hugging_face
|
||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME, sanitize_uri
|
||||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
from .merge import LanceMergeInsertBuilder
|
from .merge import LanceMergeInsertBuilder
|
||||||
from .pydantic import LanceModel, model_to_dict
|
from .pydantic import LanceModel, model_to_dict
|
||||||
@@ -57,12 +58,14 @@ from .util import (
|
|||||||
)
|
)
|
||||||
from .index import lang_mapping
|
from .index import lang_mapping
|
||||||
|
|
||||||
|
from ._lancedb import connect as lancedb_connect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import PIL
|
import PIL
|
||||||
from lance.dataset import CleanupStats, ReaderLike
|
from lance.dataset import CleanupStats, ReaderLike
|
||||||
from ._lancedb import Table as LanceDBTable, OptimizeStats
|
from ._lancedb import Table as LanceDBTable, OptimizeStats
|
||||||
from .db import LanceDBConnection
|
from .db import LanceDBConnection
|
||||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
|
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq
|
||||||
|
|
||||||
pd = safe_import_pandas()
|
pd = safe_import_pandas()
|
||||||
pl = safe_import_polars()
|
pl = safe_import_polars()
|
||||||
@@ -70,6 +73,21 @@ pl = safe_import_polars()
|
|||||||
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
||||||
|
|
||||||
|
|
||||||
|
def _pd_schema_without_embedding_funcs(
|
||||||
|
schema: Optional[pa.Schema], columns: List[str]
|
||||||
|
) -> Optional[pa.Schema]:
|
||||||
|
"""Return a schema without any embedding function columns"""
|
||||||
|
if schema is None:
|
||||||
|
return None
|
||||||
|
embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
|
||||||
|
schema.metadata
|
||||||
|
)
|
||||||
|
if not embedding_functions:
|
||||||
|
return schema
|
||||||
|
columns = set(columns)
|
||||||
|
return pa.schema([field for field in schema if field.name in columns])
|
||||||
|
|
||||||
|
|
||||||
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
||||||
if _check_for_hugging_face(data):
|
if _check_for_hugging_face(data):
|
||||||
# Huggingface datasets
|
# Huggingface datasets
|
||||||
@@ -100,10 +118,10 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
|||||||
elif isinstance(data[0], pa.RecordBatch):
|
elif isinstance(data[0], pa.RecordBatch):
|
||||||
return pa.Table.from_batches(data, schema=schema)
|
return pa.Table.from_batches(data, schema=schema)
|
||||||
else:
|
else:
|
||||||
return pa.Table.from_pylist(data)
|
return pa.Table.from_pylist(data, schema=schema)
|
||||||
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
|
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
|
||||||
# Do not add schema here, since schema may contains the vector column
|
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
|
||||||
table = pa.Table.from_pandas(data, preserve_index=False)
|
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
|
||||||
# Do not serialize Pandas metadata
|
# Do not serialize Pandas metadata
|
||||||
meta = table.schema.metadata if table.schema.metadata is not None else {}
|
meta = table.schema.metadata if table.schema.metadata is not None else {}
|
||||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||||
@@ -169,6 +187,8 @@ def sanitize_create_table(
|
|||||||
schema = schema.to_arrow_schema()
|
schema = schema.to_arrow_schema()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
|
if metadata is None and schema is not None:
|
||||||
|
metadata = schema.metadata
|
||||||
data, schema = _sanitize_data(
|
data, schema = _sanitize_data(
|
||||||
data,
|
data,
|
||||||
schema,
|
schema,
|
||||||
@@ -893,6 +913,55 @@ class Table(ABC):
|
|||||||
For most cases, the default should be fine.
|
For most cases, the default should be fine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def optimize(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cleanup_older_than: Optional[timedelta] = None,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Optimize the on-disk data and indices for better performance.
|
||||||
|
|
||||||
|
Modeled after ``VACUUM`` in PostgreSQL.
|
||||||
|
|
||||||
|
Optimization covers three operations:
|
||||||
|
|
||||||
|
* Compaction: Merges small files into larger ones
|
||||||
|
* Prune: Removes old versions of the dataset
|
||||||
|
* Index: Optimizes the indices, adding new data to existing indices
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
cleanup_older_than: timedelta, optional default 7 days
|
||||||
|
All files belonging to versions older than this will be removed. Set
|
||||||
|
to 0 days to remove all versions except the latest. The latest version
|
||||||
|
is never removed.
|
||||||
|
delete_unverified: bool, default False
|
||||||
|
Files leftover from a failed transaction may appear to be part of an
|
||||||
|
in-progress operation (e.g. appending new data) and these files will not
|
||||||
|
be deleted unless they are at least 7 days old. If delete_unverified is True
|
||||||
|
then these files will be deleted regardless of their age.
|
||||||
|
|
||||||
|
Experimental API
|
||||||
|
----------------
|
||||||
|
|
||||||
|
The optimization process is undergoing active development and may change.
|
||||||
|
Our goal with these changes is to improve the performance of optimization and
|
||||||
|
reduce the complexity.
|
||||||
|
|
||||||
|
That being said, it is essential today to run optimize if you want the best
|
||||||
|
performance. It should be stable and safe to use in production, but it our
|
||||||
|
hope that the API may be simplified (or not even need to be called) in the
|
||||||
|
future.
|
||||||
|
|
||||||
|
The frequency an application shoudl call optimize is based on the frequency of
|
||||||
|
data modifications. If data is frequently added, deleted, or updated then
|
||||||
|
optimize should be run frequently. A good rule of thumb is to run optimize if
|
||||||
|
you have added or modified 100,000 or more records or run more than 20 data
|
||||||
|
modification operations.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_columns(self, transforms: Dict[str, str]):
|
def add_columns(self, transforms: Dict[str, str]):
|
||||||
"""
|
"""
|
||||||
@@ -948,7 +1017,9 @@ class Table(ABC):
|
|||||||
return _table_uri(self._conn.uri, self.name)
|
return _table_uri(self._conn.uri, self.name)
|
||||||
|
|
||||||
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
|
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
|
||||||
if get_uri_scheme(self._dataset_uri) != "file":
|
from .remote.table import RemoteTable
|
||||||
|
|
||||||
|
if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file":
|
||||||
return ("", None, False)
|
return ("", None, False)
|
||||||
path = join_uri(self._dataset_uri, "_indices", "fts")
|
path = join_uri(self._dataset_uri, "_indices", "fts")
|
||||||
fs, path = fs_from_uri(path)
|
fs, path = fs_from_uri(path)
|
||||||
@@ -1969,6 +2040,83 @@ class LanceTable(Table):
|
|||||||
"""
|
"""
|
||||||
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
||||||
|
|
||||||
|
def optimize(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cleanup_older_than: Optional[timedelta] = None,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Optimize the on-disk data and indices for better performance.
|
||||||
|
|
||||||
|
Modeled after ``VACUUM`` in PostgreSQL.
|
||||||
|
|
||||||
|
Optimization covers three operations:
|
||||||
|
|
||||||
|
* Compaction: Merges small files into larger ones
|
||||||
|
* Prune: Removes old versions of the dataset
|
||||||
|
* Index: Optimizes the indices, adding new data to existing indices
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
cleanup_older_than: timedelta, optional default 7 days
|
||||||
|
All files belonging to versions older than this will be removed. Set
|
||||||
|
to 0 days to remove all versions except the latest. The latest version
|
||||||
|
is never removed.
|
||||||
|
delete_unverified: bool, default False
|
||||||
|
Files leftover from a failed transaction may appear to be part of an
|
||||||
|
in-progress operation (e.g. appending new data) and these files will not
|
||||||
|
be deleted unless they are at least 7 days old. If delete_unverified is True
|
||||||
|
then these files will be deleted regardless of their age.
|
||||||
|
|
||||||
|
Experimental API
|
||||||
|
----------------
|
||||||
|
|
||||||
|
The optimization process is undergoing active development and may change.
|
||||||
|
Our goal with these changes is to improve the performance of optimization and
|
||||||
|
reduce the complexity.
|
||||||
|
|
||||||
|
That being said, it is essential today to run optimize if you want the best
|
||||||
|
performance. It should be stable and safe to use in production, but it our
|
||||||
|
hope that the API may be simplified (or not even need to be called) in the
|
||||||
|
future.
|
||||||
|
|
||||||
|
The frequency an application shoudl call optimize is based on the frequency of
|
||||||
|
data modifications. If data is frequently added, deleted, or updated then
|
||||||
|
optimize should be run frequently. A good rule of thumb is to run optimize if
|
||||||
|
you have added or modified 100,000 or more records or run more than 20 data
|
||||||
|
modification operations.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
raise AssertionError(
|
||||||
|
"Synchronous method called in asynchronous context. "
|
||||||
|
"If you are writing an asynchronous application "
|
||||||
|
"then please use the asynchronous APIs"
|
||||||
|
)
|
||||||
|
|
||||||
|
except RuntimeError:
|
||||||
|
asyncio.run(
|
||||||
|
self._async_optimize(
|
||||||
|
cleanup_older_than=cleanup_older_than,
|
||||||
|
delete_unverified=delete_unverified,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.checkout_latest()
|
||||||
|
|
||||||
|
async def _async_optimize(
|
||||||
|
self,
|
||||||
|
cleanup_older_than: Optional[timedelta] = None,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
):
|
||||||
|
conn = await lancedb_connect(
|
||||||
|
sanitize_uri(self._conn.uri),
|
||||||
|
)
|
||||||
|
table = AsyncTable(await conn.open_table(self.name))
|
||||||
|
return await table.optimize(
|
||||||
|
cleanup_older_than=cleanup_older_than, delete_unverified=delete_unverified
|
||||||
|
)
|
||||||
|
|
||||||
def add_columns(self, transforms: Dict[str, str]):
|
def add_columns(self, transforms: Dict[str, str]):
|
||||||
self._dataset_mut.add_columns(transforms)
|
self._dataset_mut.add_columns(transforms)
|
||||||
|
|
||||||
@@ -2382,7 +2530,9 @@ class AsyncTable:
|
|||||||
column: str,
|
column: str,
|
||||||
*,
|
*,
|
||||||
replace: Optional[bool] = None,
|
replace: Optional[bool] = None,
|
||||||
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
|
config: Optional[
|
||||||
|
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
||||||
|
] = None,
|
||||||
):
|
):
|
||||||
"""Create an index to speed up queries
|
"""Create an index to speed up queries
|
||||||
|
|
||||||
@@ -2535,7 +2685,44 @@ class AsyncTable:
|
|||||||
async def _execute_query(
|
async def _execute_query(
|
||||||
self, query: Query, batch_size: Optional[int] = None
|
self, query: Query, batch_size: Optional[int] = None
|
||||||
) -> pa.RecordBatchReader:
|
) -> pa.RecordBatchReader:
|
||||||
pass
|
# The sync remote table calls into this method, so we need to map the
|
||||||
|
# query to the async version of the query and run that here. This is only
|
||||||
|
# used for that code path right now.
|
||||||
|
async_query = self.query().limit(query.k)
|
||||||
|
if query.offset > 0:
|
||||||
|
async_query = async_query.offset(query.offset)
|
||||||
|
if query.columns:
|
||||||
|
async_query = async_query.select(query.columns)
|
||||||
|
if query.filter:
|
||||||
|
async_query = async_query.where(query.filter)
|
||||||
|
if query.fast_search:
|
||||||
|
async_query = async_query.fast_search()
|
||||||
|
if query.with_row_id:
|
||||||
|
async_query = async_query.with_row_id()
|
||||||
|
|
||||||
|
if query.vector:
|
||||||
|
async_query = (
|
||||||
|
async_query.nearest_to(query.vector)
|
||||||
|
.distance_type(query.metric)
|
||||||
|
.nprobes(query.nprobes)
|
||||||
|
)
|
||||||
|
if query.refine_factor:
|
||||||
|
async_query = async_query.refine_factor(query.refine_factor)
|
||||||
|
if query.vector_column:
|
||||||
|
async_query = async_query.column(query.vector_column)
|
||||||
|
|
||||||
|
if not query.prefilter:
|
||||||
|
async_query = async_query.postfilter()
|
||||||
|
|
||||||
|
if isinstance(query.full_text_query, str):
|
||||||
|
async_query = async_query.nearest_to_text(query.full_text_query)
|
||||||
|
elif isinstance(query.full_text_query, dict):
|
||||||
|
fts_query = query.full_text_query["query"]
|
||||||
|
fts_columns = query.full_text_query.get("columns", []) or []
|
||||||
|
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
|
||||||
|
|
||||||
|
table = await async_query.to_arrow()
|
||||||
|
return table.to_reader()
|
||||||
|
|
||||||
async def _do_merge(
|
async def _do_merge(
|
||||||
self,
|
self,
|
||||||
@@ -2781,7 +2968,7 @@ class AsyncTable:
|
|||||||
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
|
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
|
||||||
return await self._inner.optimize(cleanup_older_than, delete_unverified)
|
return await self._inner.optimize(cleanup_older_than, delete_unverified)
|
||||||
|
|
||||||
async def list_indices(self) -> IndexConfig:
|
async def list_indices(self) -> Iterable[IndexConfig]:
|
||||||
"""
|
"""
|
||||||
List all indices that have been created with Self::create_index
|
List all indices that have been created with Self::create_index
|
||||||
"""
|
"""
|
||||||
@@ -2865,3 +3052,8 @@ class IndexStatistics:
|
|||||||
]
|
]
|
||||||
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
|
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
|
||||||
num_indices: Optional[int] = None
|
num_indices: Optional[int] = None
|
||||||
|
|
||||||
|
# This exists for backwards compatibility with an older API, which returned
|
||||||
|
# a dictionary instead of a class.
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
|||||||
@@ -1,15 +1,6 @@
|
|||||||
# Copyright 2023 LanceDB Developers
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
#
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
@@ -18,6 +9,7 @@ import lancedb
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pytest
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
from lancedb.conftest import MockTextEmbeddingFunction
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
from lancedb.embeddings import (
|
from lancedb.embeddings import (
|
||||||
EmbeddingFunctionConfig,
|
EmbeddingFunctionConfig,
|
||||||
@@ -129,6 +121,142 @@ def test_embedding_with_bad_results(tmp_path):
|
|||||||
# assert tbl["vector"].null_count == 1
|
# assert tbl["vector"].null_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_existing_vectors(tmp_path):
|
||||||
|
@register("mock-embedding")
|
||||||
|
class MockEmbeddingFunction(TextEmbeddingFunction):
|
||||||
|
def ndims(self):
|
||||||
|
return 128
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
|
||||||
|
|
||||||
|
registry = get_registry()
|
||||||
|
model = registry.get("mock-embedding").create()
|
||||||
|
|
||||||
|
class Schema(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||||
|
tbl.add([{"text": "hello world", "vector": np.zeros(128).tolist()}])
|
||||||
|
|
||||||
|
embeddings = tbl.to_arrow()["vector"].to_pylist()
|
||||||
|
assert not np.any(embeddings), "all zeros"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_function_with_pandas(tmp_path):
|
||||||
|
@register("mock-embedding")
|
||||||
|
class _MockEmbeddingFunction(TextEmbeddingFunction):
|
||||||
|
def ndims(self):
|
||||||
|
return 128
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
|
||||||
|
|
||||||
|
registery = get_registry()
|
||||||
|
func = registery.get("mock-embedding").create()
|
||||||
|
|
||||||
|
class TestSchema(LanceModel):
|
||||||
|
text: str = func.SourceField()
|
||||||
|
val: int
|
||||||
|
vector: Vector(func.ndims()) = func.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"text": ["hello world", "goodbye world"],
|
||||||
|
"val": [1, 2],
|
||||||
|
"not-used": ["s1", "s3"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df)
|
||||||
|
schema = tbl.schema
|
||||||
|
assert schema.field("text").type == pa.string()
|
||||||
|
assert schema.field("val").type == pa.int64()
|
||||||
|
assert schema.field("vector").type == pa.list_(pa.float32(), 128)
|
||||||
|
|
||||||
|
df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"text": ["extra", "more"],
|
||||||
|
"val": [4, 5],
|
||||||
|
"misc-col": ["s1", "s3"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tbl.add(df)
|
||||||
|
|
||||||
|
assert tbl.count_rows() == 4
|
||||||
|
embeddings = tbl.to_arrow()["vector"]
|
||||||
|
assert embeddings.null_count == 0
|
||||||
|
|
||||||
|
df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"text": ["with", "embeddings"],
|
||||||
|
"val": [6, 7],
|
||||||
|
"vector": [np.zeros(128).tolist(), np.zeros(128).tolist()],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tbl.add(df)
|
||||||
|
|
||||||
|
embeddings = tbl.search().where("val > 5").to_arrow()["vector"].to_pylist()
|
||||||
|
assert not np.any(embeddings), "all zeros"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_embeddings_for_pandas(tmp_path):
|
||||||
|
@register("mock-embedding")
|
||||||
|
class MockFunc1(TextEmbeddingFunction):
|
||||||
|
def ndims(self):
|
||||||
|
return 128
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
|
||||||
|
|
||||||
|
@register("mock-embedding2")
|
||||||
|
class MockFunc2(TextEmbeddingFunction):
|
||||||
|
def ndims(self):
|
||||||
|
return 512
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
|
||||||
|
|
||||||
|
registery = get_registry()
|
||||||
|
func1 = registery.get("mock-embedding").create()
|
||||||
|
func2 = registery.get("mock-embedding2").create()
|
||||||
|
|
||||||
|
class TestSchema(LanceModel):
|
||||||
|
text: str = func1.SourceField()
|
||||||
|
val: int
|
||||||
|
vec1: Vector(func1.ndims()) = func1.VectorField()
|
||||||
|
prompt: str = func2.SourceField()
|
||||||
|
vec2: Vector(func2.ndims()) = func2.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"text": ["hello world", "goodbye world"],
|
||||||
|
"val": [1, 2],
|
||||||
|
"prompt": ["hello", "goodbye"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df)
|
||||||
|
|
||||||
|
schema = tbl.schema
|
||||||
|
assert schema.field("text").type == pa.string()
|
||||||
|
assert schema.field("val").type == pa.int64()
|
||||||
|
assert schema.field("vec1").type == pa.list_(pa.float32(), 128)
|
||||||
|
assert schema.field("prompt").type == pa.string()
|
||||||
|
assert schema.field("vec2").type == pa.list_(pa.float32(), 512)
|
||||||
|
assert tbl.count_rows() == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_embedding_function_rate_limit(tmp_path):
|
def test_embedding_function_rate_limit(tmp_path):
|
||||||
def _get_schema_from_model(model):
|
def _get_schema_from_model(model):
|
||||||
@@ -196,6 +324,7 @@ def test_add_optional_vector(tmp_path):
|
|||||||
"ollama",
|
"ollama",
|
||||||
"cohere",
|
"cohere",
|
||||||
"instructor",
|
"instructor",
|
||||||
|
"voyageai",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_embedding_function_safe_model_dump(embedding_type):
|
def test_embedding_function_safe_model_dump(embedding_type):
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import lancedb
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
|
||||||
from lancedb.embeddings import get_registry
|
from lancedb.embeddings import get_registry
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
@@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path):
|
|||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_openclip(tmp_path):
|
def test_openclip(tmp_path):
|
||||||
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
@@ -481,3 +481,22 @@ def test_ollama_embedding(tmp_path):
|
|||||||
json.dumps(dumped_model)
|
json.dumps(dumped_model)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
pytest.fail("Failed to JSON serialize the dumped model")
|
pytest.fail("Failed to JSON serialize the dumped model")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||||
|
)
|
||||||
|
def test_voyageai_embedding_function():
|
||||||
|
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = voyageai.SourceField()
|
||||||
|
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
db = lancedb.connect("~/lancedb")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||||
|
|||||||
@@ -235,6 +235,29 @@ async def test_search_fts_async(async_table):
|
|||||||
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
|
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
|
||||||
assert len(results) == 5
|
assert len(results) == 5
|
||||||
|
|
||||||
|
expected_count = await async_table.count_rows(
|
||||||
|
"count > 5000 and contains(text, 'puppy')"
|
||||||
|
)
|
||||||
|
expected_count = min(expected_count, 10)
|
||||||
|
|
||||||
|
limited_results_pre_filter = await (
|
||||||
|
async_table.query()
|
||||||
|
.nearest_to_text("puppy")
|
||||||
|
.where("count > 5000")
|
||||||
|
.limit(10)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(limited_results_pre_filter) == expected_count
|
||||||
|
limited_results_post_filter = await (
|
||||||
|
async_table.query()
|
||||||
|
.nearest_to_text("puppy")
|
||||||
|
.where("count > 5000")
|
||||||
|
.limit(10)
|
||||||
|
.postfilter()
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(limited_results_post_filter) <= expected_count
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_fts_specify_column_async(async_table):
|
async def test_search_fts_specify_column_async(async_table):
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
|||||||
# Can recreate if replace=True
|
# Can recreate if replace=True
|
||||||
await some_table.create_index("id", replace=True)
|
await some_table.create_index("id", replace=True)
|
||||||
indices = await some_table.list_indices()
|
indices = await some_table.list_indices()
|
||||||
assert str(indices) == '[Index(BTree, columns=["id"])]'
|
assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
|
||||||
assert len(indices) == 1
|
assert len(indices) == 1
|
||||||
assert indices[0].index_type == "BTree"
|
assert indices[0].index_type == "BTree"
|
||||||
assert indices[0].columns == ["id"]
|
assert indices[0].columns == ["id"]
|
||||||
@@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
|||||||
async def test_create_bitmap_index(some_table: AsyncTable):
|
async def test_create_bitmap_index(some_table: AsyncTable):
|
||||||
await some_table.create_index("id", config=Bitmap())
|
await some_table.create_index("id", config=Bitmap())
|
||||||
indices = await some_table.list_indices()
|
indices = await some_table.list_indices()
|
||||||
assert str(indices) == '[Index(Bitmap, columns=["id"])]'
|
assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]'
|
||||||
indices = await some_table.list_indices()
|
indices = await some_table.list_indices()
|
||||||
assert len(indices) == 1
|
assert len(indices) == 1
|
||||||
index_name = indices[0].name
|
index_name = indices[0].name
|
||||||
@@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
|
|||||||
async def test_create_label_list_index(some_table: AsyncTable):
|
async def test_create_label_list_index(some_table: AsyncTable):
|
||||||
await some_table.create_index("tags", config=LabelList())
|
await some_table.create_index("tags", config=LabelList())
|
||||||
indices = await some_table.list_indices()
|
indices = await some_table.list_indices()
|
||||||
assert str(indices) == '[Index(LabelList, columns=["tags"])]'
|
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import lance
|
import lance
|
||||||
import lancedb
|
import lancedb
|
||||||
|
from lancedb.index import IvfPq
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas.testing as tm
|
import pandas.testing as tm
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -330,6 +331,12 @@ async def test_query_async(table_async: AsyncTable):
|
|||||||
# Also check an empty query
|
# Also check an empty query
|
||||||
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
|
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
|
||||||
|
|
||||||
|
# with row id
|
||||||
|
await check_query(
|
||||||
|
table_async.query().select(["id", "vector"]).with_row_id(),
|
||||||
|
expected_columns=["id", "vector", "_rowid"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_to_arrow_async(table_async: AsyncTable):
|
async def test_query_to_arrow_async(table_async: AsyncTable):
|
||||||
@@ -358,6 +365,25 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
|
|||||||
assert df.shape == (0, 4)
|
assert df.shape == (0, 4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fast_search_async(tmp_path):
|
||||||
|
db = await lancedb.connect_async(tmp_path)
|
||||||
|
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
|
||||||
|
np.random.rand(256, 32)
|
||||||
|
).storage
|
||||||
|
table = await db.create_table("test", pa.table({"vector": vectors}))
|
||||||
|
await table.create_index(
|
||||||
|
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=1)
|
||||||
|
)
|
||||||
|
await table.add(pa.table({"vector": vectors}))
|
||||||
|
|
||||||
|
q = [1.0] * 32
|
||||||
|
plan = await table.query().nearest_to(q).explain_plan(True)
|
||||||
|
assert "LanceScan" in plan
|
||||||
|
plan = await table.query().nearest_to(q).fast_search().explain_plan(True)
|
||||||
|
assert "LanceScan" not in plan
|
||||||
|
|
||||||
|
|
||||||
def test_explain_plan(table):
|
def test_explain_plan(table):
|
||||||
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
|
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
plan = q.explain_plan(verbose=True)
|
plan = q.explain_plan(verbose=True)
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
# Copyright 2023 LanceDB Developers
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import attrs
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import pyarrow as pa
|
|
||||||
import pytest
|
|
||||||
from aiohttp import web
|
|
||||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
|
||||||
|
|
||||||
|
|
||||||
@attrs.define
|
|
||||||
class MockLanceDBServer:
|
|
||||||
runner: web.AppRunner = attrs.field(init=False)
|
|
||||||
site: web.TCPSite = attrs.field(init=False)
|
|
||||||
|
|
||||||
async def query_handler(self, request: web.Request) -> web.Response:
|
|
||||||
table_name = request.match_info["table_name"]
|
|
||||||
assert table_name == "test_table"
|
|
||||||
|
|
||||||
await request.json()
|
|
||||||
# TODO: do some matching
|
|
||||||
|
|
||||||
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
|
|
||||||
ids = pd.Series(range(10), name="id")
|
|
||||||
df = pd.DataFrame([vecs, ids]).T
|
|
||||||
|
|
||||||
batch = pa.RecordBatch.from_pandas(
|
|
||||||
df,
|
|
||||||
schema=pa.schema(
|
|
||||||
[
|
|
||||||
pa.field("vector", pa.list_(pa.float32(), 128)),
|
|
||||||
pa.field("id", pa.int64()),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
sink = pa.BufferOutputStream()
|
|
||||||
with pa.ipc.new_file(sink, batch.schema) as writer:
|
|
||||||
writer.write_batch(batch)
|
|
||||||
|
|
||||||
return web.Response(body=sink.getvalue().to_pybytes())
|
|
||||||
|
|
||||||
async def setup(self):
|
|
||||||
app = web.Application()
|
|
||||||
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
|
|
||||||
self.runner = web.AppRunner(app)
|
|
||||||
await self.runner.setup()
|
|
||||||
self.site = web.TCPSite(self.runner, "localhost", 8111)
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
await self.site.start()
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
await self.runner.cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="flaky somehow, fix later")
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_e2e_with_mock_server():
|
|
||||||
mock_server = MockLanceDBServer()
|
|
||||||
await mock_server.setup()
|
|
||||||
await mock_server.start()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
|
|
||||||
df = (
|
|
||||||
await client.query(
|
|
||||||
"test_table",
|
|
||||||
VectorQuery(
|
|
||||||
vector=np.random.rand(128).tolist(),
|
|
||||||
k=10,
|
|
||||||
_metric="L2",
|
|
||||||
columns=["id", "vector"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
).to_pandas()
|
|
||||||
|
|
||||||
assert "vector" in df.columns
|
|
||||||
assert "id" in df.columns
|
|
||||||
|
|
||||||
assert client.closed
|
|
||||||
finally:
|
|
||||||
# make sure we don't leak resources
|
|
||||||
await mock_server.stop()
|
|
||||||
@@ -2,91 +2,19 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from datetime import timedelta
|
||||||
import http.server
|
import http.server
|
||||||
|
import json
|
||||||
import threading
|
import threading
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
|
from lancedb.remote import ClientConfig
|
||||||
from lancedb.remote.errors import HttpError, RetryError
|
from lancedb.remote.errors import HttpError, RetryError
|
||||||
import pyarrow as pa
|
|
||||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
class FakeLanceDBClient:
|
|
||||||
def close(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
|
||||||
assert table_name == "test"
|
|
||||||
t = pa.schema([]).empty_table()
|
|
||||||
return VectorQueryResult(t)
|
|
||||||
|
|
||||||
def post(self, path: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mount_retry_adapter_for_table(self, table_name: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def test_remote_db():
|
|
||||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
|
||||||
setattr(conn, "_client", FakeLanceDBClient())
|
|
||||||
|
|
||||||
table = conn["test"]
|
|
||||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
|
||||||
table.search([1.0, 2.0]).to_pandas()
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_empty_table():
|
|
||||||
client = MagicMock()
|
|
||||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
|
||||||
|
|
||||||
conn._client = client
|
|
||||||
|
|
||||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
|
||||||
|
|
||||||
client.post.return_value = {"status": "ok"}
|
|
||||||
table = conn.create_table("test", schema=schema)
|
|
||||||
assert table.name == "test"
|
|
||||||
assert client.post.call_args[0][0] == "/v1/table/test/create/"
|
|
||||||
|
|
||||||
json_schema = {
|
|
||||||
"fields": [
|
|
||||||
{
|
|
||||||
"name": "vector",
|
|
||||||
"nullable": True,
|
|
||||||
"type": {
|
|
||||||
"type": "fixed_size_list",
|
|
||||||
"fields": [
|
|
||||||
{"name": "item", "nullable": True, "type": {"type": "float"}}
|
|
||||||
],
|
|
||||||
"length": 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
client.post.return_value = {"schema": json_schema}
|
|
||||||
assert table.schema == schema
|
|
||||||
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
|
|
||||||
|
|
||||||
client.post.return_value = 0
|
|
||||||
assert table.count_rows(None) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_table_with_recordbatches():
|
|
||||||
client = MagicMock()
|
|
||||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
|
||||||
|
|
||||||
conn._client = client
|
|
||||||
|
|
||||||
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
|
|
||||||
|
|
||||||
client.post.return_value = {"status": "ok"}
|
|
||||||
table = conn.create_table("test", [batch], schema=batch.schema)
|
|
||||||
assert table.name == "test"
|
|
||||||
assert client.post.call_args[0][0] == "/v1/table/test/create/"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_http_handler(handler):
|
def make_mock_http_handler(handler):
|
||||||
@@ -100,8 +28,35 @@ def make_mock_http_handler(handler):
|
|||||||
return MockLanceDBHandler
|
return MockLanceDBHandler
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def mock_lancedb_connection(handler):
|
||||||
|
with http.server.HTTPServer(
|
||||||
|
("localhost", 8080), make_mock_http_handler(handler)
|
||||||
|
) as server:
|
||||||
|
handle = threading.Thread(target=server.serve_forever)
|
||||||
|
handle.start()
|
||||||
|
|
||||||
|
db = lancedb.connect(
|
||||||
|
"db://dev",
|
||||||
|
api_key="fake",
|
||||||
|
host_override="http://localhost:8080",
|
||||||
|
client_config={
|
||||||
|
"retry_config": {"retries": 2},
|
||||||
|
"timeout_config": {
|
||||||
|
"connect_timeout": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
server.shutdown()
|
||||||
|
handle.join()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def mock_lancedb_connection(handler):
|
async def mock_lancedb_connection_async(handler):
|
||||||
with http.server.HTTPServer(
|
with http.server.HTTPServer(
|
||||||
("localhost", 8080), make_mock_http_handler(handler)
|
("localhost", 8080), make_mock_http_handler(handler)
|
||||||
) as server:
|
) as server:
|
||||||
@@ -143,7 +98,7 @@ async def test_async_remote_db():
|
|||||||
request.end_headers()
|
request.end_headers()
|
||||||
request.wfile.write(b'{"tables": []}')
|
request.wfile.write(b'{"tables": []}')
|
||||||
|
|
||||||
async with mock_lancedb_connection(handler) as db:
|
async with mock_lancedb_connection_async(handler) as db:
|
||||||
table_names = await db.table_names()
|
table_names = await db.table_names()
|
||||||
assert table_names == []
|
assert table_names == []
|
||||||
|
|
||||||
@@ -159,12 +114,12 @@ async def test_http_error():
|
|||||||
request.end_headers()
|
request.end_headers()
|
||||||
request.wfile.write(b"Internal Server Error")
|
request.wfile.write(b"Internal Server Error")
|
||||||
|
|
||||||
async with mock_lancedb_connection(handler) as db:
|
async with mock_lancedb_connection_async(handler) as db:
|
||||||
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
|
with pytest.raises(HttpError) as exc_info:
|
||||||
await db.table_names()
|
await db.table_names()
|
||||||
|
|
||||||
assert exc_info.value.request_id == request_id_holder["request_id"]
|
assert exc_info.value.request_id == request_id_holder["request_id"]
|
||||||
assert exc_info.value.status_code == 507
|
assert "Internal Server Error" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -178,15 +133,253 @@ async def test_retry_error():
|
|||||||
request.end_headers()
|
request.end_headers()
|
||||||
request.wfile.write(b"Try again later")
|
request.wfile.write(b"Try again later")
|
||||||
|
|
||||||
async with mock_lancedb_connection(handler) as db:
|
async with mock_lancedb_connection_async(handler) as db:
|
||||||
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
|
with pytest.raises(RetryError) as exc_info:
|
||||||
await db.table_names()
|
await db.table_names()
|
||||||
|
|
||||||
assert exc_info.value.request_id == request_id_holder["request_id"]
|
assert exc_info.value.request_id == request_id_holder["request_id"]
|
||||||
assert exc_info.value.status_code == 429
|
|
||||||
|
|
||||||
cause = exc_info.value.__cause__
|
cause = exc_info.value.__cause__
|
||||||
assert isinstance(cause, HttpError)
|
assert isinstance(cause, HttpError)
|
||||||
assert "Try again later" in str(cause)
|
assert "Try again later" in str(cause)
|
||||||
assert cause.request_id == request_id_holder["request_id"]
|
assert cause.request_id == request_id_holder["request_id"]
|
||||||
assert cause.status_code == 429
|
assert cause.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def query_test_table(query_handler):
|
||||||
|
def handler(request):
|
||||||
|
if request.path == "/v1/table/test/describe/":
|
||||||
|
request.send_response(200)
|
||||||
|
request.send_header("Content-Type", "application/json")
|
||||||
|
request.end_headers()
|
||||||
|
request.wfile.write(b"{}")
|
||||||
|
elif request.path == "/v1/table/test/query/":
|
||||||
|
content_len = int(request.headers.get("Content-Length"))
|
||||||
|
body = request.rfile.read(content_len)
|
||||||
|
body = json.loads(body)
|
||||||
|
|
||||||
|
data = query_handler(body)
|
||||||
|
|
||||||
|
request.send_response(200)
|
||||||
|
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
|
||||||
|
request.end_headers()
|
||||||
|
|
||||||
|
with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
|
||||||
|
f.write_table(data)
|
||||||
|
else:
|
||||||
|
request.send_response(404)
|
||||||
|
request.end_headers()
|
||||||
|
|
||||||
|
with mock_lancedb_connection(handler) as db:
|
||||||
|
assert repr(db) == "RemoteConnect(name=dev)"
|
||||||
|
table = db.open_table("test")
|
||||||
|
assert repr(table) == "RemoteTable(dev.test)"
|
||||||
|
yield table
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_sync_minimal():
|
||||||
|
def handler(body):
|
||||||
|
assert body == {
|
||||||
|
"distance_type": "l2",
|
||||||
|
"k": 10,
|
||||||
|
"prefilter": False,
|
||||||
|
"refine_factor": None,
|
||||||
|
"vector": [1.0, 2.0, 3.0],
|
||||||
|
"nprobes": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
return pa.table({"id": [1, 2, 3]})
|
||||||
|
|
||||||
|
with query_test_table(handler) as table:
|
||||||
|
data = table.search([1, 2, 3]).to_list()
|
||||||
|
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||||
|
assert data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_sync_empty_query():
|
||||||
|
def handler(body):
|
||||||
|
assert body == {
|
||||||
|
"k": 10,
|
||||||
|
"filter": "true",
|
||||||
|
"vector": [],
|
||||||
|
"columns": ["id"],
|
||||||
|
}
|
||||||
|
|
||||||
|
return pa.table({"id": [1, 2, 3]})
|
||||||
|
|
||||||
|
with query_test_table(handler) as table:
|
||||||
|
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
|
||||||
|
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||||
|
assert data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_sync_maximal():
|
||||||
|
def handler(body):
|
||||||
|
assert body == {
|
||||||
|
"distance_type": "cosine",
|
||||||
|
"k": 42,
|
||||||
|
"prefilter": True,
|
||||||
|
"refine_factor": 10,
|
||||||
|
"vector": [1.0, 2.0, 3.0],
|
||||||
|
"nprobes": 5,
|
||||||
|
"filter": "id > 0",
|
||||||
|
"columns": ["id", "name"],
|
||||||
|
"vector_column": "vector2",
|
||||||
|
"fast_search": True,
|
||||||
|
"with_row_id": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
|
||||||
|
|
||||||
|
with query_test_table(handler) as table:
|
||||||
|
(
|
||||||
|
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
|
||||||
|
.metric("cosine")
|
||||||
|
.limit(42)
|
||||||
|
.refine_factor(10)
|
||||||
|
.nprobes(5)
|
||||||
|
.where("id > 0", prefilter=True)
|
||||||
|
.with_row_id(True)
|
||||||
|
.select(["id", "name"])
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_sync_multiple_vectors():
|
||||||
|
def handler(_body):
|
||||||
|
return pa.table({"id": [1]})
|
||||||
|
|
||||||
|
with query_test_table(handler) as table:
|
||||||
|
results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
|
||||||
|
assert len(results) == 2
|
||||||
|
results.sort(key=lambda x: x["query_index"])
|
||||||
|
assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_sync_fts():
|
||||||
|
def handler(body):
|
||||||
|
assert body == {
|
||||||
|
"full_text_query": {
|
||||||
|
"query": "puppy",
|
||||||
|
"columns": [],
|
||||||
|
},
|
||||||
|
"k": 10,
|
||||||
|
"vector": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
return pa.table({"id": [1, 2, 3]})
|
||||||
|
|
||||||
|
with query_test_table(handler) as table:
|
||||||
|
(table.search("puppy", query_type="fts").to_list())
|
||||||
|
|
||||||
|
def handler(body):
|
||||||
|
assert body == {
|
||||||
|
"full_text_query": {
|
||||||
|
"query": "puppy",
|
||||||
|
"columns": ["name", "description"],
|
||||||
|
},
|
||||||
|
"k": 42,
|
||||||
|
"vector": [],
|
||||||
|
"with_row_id": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
return pa.table({"id": [1, 2, 3]})
|
||||||
|
|
||||||
|
with query_test_table(handler) as table:
|
||||||
|
(
|
||||||
|
table.search("puppy", query_type="fts", fts_columns=["name", "description"])
|
||||||
|
.with_row_id(True)
|
||||||
|
.limit(42)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_sync_hybrid():
|
||||||
|
def handler(body):
|
||||||
|
if "full_text_query" in body:
|
||||||
|
# FTS query
|
||||||
|
assert body == {
|
||||||
|
"full_text_query": {
|
||||||
|
"query": "puppy",
|
||||||
|
"columns": [],
|
||||||
|
},
|
||||||
|
"k": 42,
|
||||||
|
"vector": [],
|
||||||
|
"with_row_id": True,
|
||||||
|
}
|
||||||
|
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
|
||||||
|
else:
|
||||||
|
# Vector query
|
||||||
|
assert body == {
|
||||||
|
"distance_type": "l2",
|
||||||
|
"k": 42,
|
||||||
|
"prefilter": False,
|
||||||
|
"refine_factor": None,
|
||||||
|
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
"nprobes": 20,
|
||||||
|
"with_row_id": True,
|
||||||
|
}
|
||||||
|
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
|
||||||
|
|
||||||
|
with query_test_table(handler) as table:
|
||||||
|
embedding_func = MockTextEmbeddingFunction()
|
||||||
|
embedding_config = MagicMock()
|
||||||
|
embedding_config.function = embedding_func
|
||||||
|
|
||||||
|
embedding_funcs = MagicMock()
|
||||||
|
embedding_funcs.get = MagicMock(return_value=embedding_config)
|
||||||
|
table.embedding_functions = embedding_funcs
|
||||||
|
|
||||||
|
(table.search("puppy", query_type="hybrid").limit(42).to_list())
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_client():
    """Exercise every supported way of passing client_config to connect().

    Covers: no config at all, an empty dict, a ``ClientConfig`` instance
    whose sub-configs are dicts, a fully nested plain-dict form, and the
    deprecated flat keyword arguments, which must still take effect while
    raising a ``DeprecationWarning``.
    """
    required_kwargs = {
        "uri": "db://dev",
        "api_key": "fake-api-key",
        "region": "us-east-1",
    }

    # No client_config: a default ClientConfig is created.
    conn = lancedb.connect(**required_kwargs)
    assert isinstance(conn.client_config, ClientConfig)

    # An empty dict is coerced into a default ClientConfig.
    conn = lancedb.connect(**required_kwargs, client_config={})
    assert isinstance(conn.client_config, ClientConfig)

    # ClientConfig instance with timeout_config given as a dict; a bare
    # number is interpreted as seconds.
    conn = lancedb.connect(
        **required_kwargs,
        client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
    )
    assert isinstance(conn.client_config, ClientConfig)
    assert conn.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    # Fully nested dict form with an explicit timedelta value.
    conn = lancedb.connect(
        **required_kwargs,
        client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
    )
    assert isinstance(conn.client_config, ClientConfig)
    assert conn.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    # retry_config as a dict inside a ClientConfig instance.
    conn = lancedb.connect(
        **required_kwargs, client_config=ClientConfig(retry_config={"retries": 42})
    )
    assert isinstance(conn.client_config, ClientConfig)
    assert conn.client_config.retry_config.retries == 42

    # retry_config in the nested plain-dict form.
    conn = lancedb.connect(
        **required_kwargs, client_config={"retry_config": {"retries": 42}}
    )
    assert isinstance(conn.client_config, ClientConfig)
    assert conn.client_config.retry_config.retries == 42

    # Deprecated flat kwargs are still honored but must warn.
    with pytest.warns(DeprecationWarning):
        conn = lancedb.connect(**required_kwargs, connection_timeout=42)
    assert conn.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        conn = lancedb.connect(**required_kwargs, read_timeout=42)
    assert conn.client_config.timeout_config.read_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        lancedb.connect(**required_kwargs, request_thread_pool=10)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from lancedb.rerankers import (
|
|||||||
OpenaiReranker,
|
OpenaiReranker,
|
||||||
JinaReranker,
|
JinaReranker,
|
||||||
AnswerdotaiRerankers,
|
AnswerdotaiRerankers,
|
||||||
|
VoyageAIReranker,
|
||||||
)
|
)
|
||||||
from lancedb.table import LanceTable
|
from lancedb.table import LanceTable
|
||||||
|
|
||||||
@@ -344,3 +345,14 @@ def test_jina_reranker(tmp_path, use_tantivy):
|
|||||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||||
reranker = JinaReranker()
|
reranker = JinaReranker()
|
||||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_voyageai_reranker(tmp_path, use_tantivy):
    """Run the shared reranker smoke test against the VoyageAI reranker.

    Skipped unless VOYAGE_API_KEY is set in the environment, and also
    skipped when the optional ``voyageai`` package is not installed.
    """
    pytest.importorskip("voyageai")
    voyage = VoyageAIReranker(model_name="rerank-2")
    tbl, tbl_schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(voyage, tbl, "single player experience", None, tbl_schema)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user