diff --git a/.bumpversion.toml b/.bumpversion.toml
index 57a76fbd..2caae397 100644
--- a/.bumpversion.toml
+++ b/.bumpversion.toml
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.12.0"
+current_version = "0.13.0-beta.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
@@ -92,6 +92,11 @@
 glob = "node/package.json"
 replace = "\"@lancedb/vectordb-win32-x64-msvc\": \"{new_version}\""
 search = "\"@lancedb/vectordb-win32-x64-msvc\": \"{current_version}\""
+[[tool.bumpversion.files]]
+glob = "node/package.json"
+replace = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{new_version}\""
+search = "\"@lancedb/vectordb-win32-arm64-msvc\": \"{current_version}\""
+
 # Cargo files
 # ------------
 [[tool.bumpversion.files]]
diff --git a/.cargo/config.toml b/.cargo/config.toml
index ec0369f8..7a5e31a8 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -38,3 +38,7 @@ rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm
 # not found errors on systems that are missing it.
 [target.x86_64-pc-windows-msvc]
 rustflags = ["-Ctarget-feature=+crt-static"]
+
+# Experimental target for Arm64 Windows
+[target.aarch64-pc-windows-msvc]
+rustflags = ["-Ctarget-feature=+crt-static"]
\ No newline at end of file
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 25a941e8..ab02b499 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -35,21 +35,21 @@ jobs:
       CC: clang-18
       CXX: clang++-18
     steps:
-    - uses: actions/checkout@v4
-      with:
-        fetch-depth: 0
-        lfs: true
-    - uses: Swatinem/rust-cache@v2
-      with:
-        workspaces: rust
-    - name: Install dependencies
-      run: |
-        sudo apt update
-        sudo apt install -y protobuf-compiler libssl-dev
-    - name: Run format
-      run: cargo fmt --all -- --check
-    - name: Run clippy
-      run: cargo clippy --workspace --tests --all-features -- -D warnings
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          lfs: true
+      - uses: Swatinem/rust-cache@v2
+        with:
+          workspaces: rust
+      - name: Install dependencies
+        run: |
+          sudo apt update
+          sudo apt install -y protobuf-compiler libssl-dev
+      - name: Run format
+        run: cargo fmt --all -- --check
+      - name: Run clippy
+        run: cargo clippy --workspace --tests --all-features -- -D warnings
   linux:
     timeout-minutes: 30
     # To build all features, we need more disk space than is available
@@ -65,37 +65,37 @@
       CC: clang-18
       CXX: clang++-18
     steps:
-    - uses: actions/checkout@v4
-      with:
+      - uses: actions/checkout@v4
+        with:
           fetch-depth: 0
           lfs: true
-    - uses: Swatinem/rust-cache@v2
-      with:
+      - uses: Swatinem/rust-cache@v2
+        with:
           workspaces: rust
-    - name: Install dependencies
-      run: |
+      - name: Install dependencies
+        run: |
           sudo apt update
           sudo apt install -y protobuf-compiler libssl-dev
-    - name: Make Swap
-      run: |
-        sudo fallocate -l 16G /swapfile
-        sudo chmod 600 /swapfile
-        sudo mkswap /swapfile
-        sudo swapon /swapfile
-    - name: Start S3 integration test environment
-      working-directory: .
-      run: docker compose up --detach --wait
-    - name: Build
-      run: cargo build --all-features
-    - name: Run tests
-      run: cargo test --all-features
-    - name: Run examples
-      run: cargo run --example simple
+      - name: Make Swap
+        run: |
+          sudo fallocate -l 16G /swapfile
+          sudo chmod 600 /swapfile
+          sudo mkswap /swapfile
+          sudo swapon /swapfile
+      - name: Start S3 integration test environment
+        working-directory: .
+ run: docker compose up --detach --wait + - name: Build + run: cargo build --all-features + - name: Run tests + run: cargo test --all-features + - name: Run examples + run: cargo run --example simple macos: timeout-minutes: 30 strategy: matrix: - mac-runner: [ "macos-13", "macos-14" ] + mac-runner: ["macos-13", "macos-14"] runs-on: "${{ matrix.mac-runner }}" defaults: run: @@ -104,8 +104,8 @@ jobs: steps: - uses: actions/checkout@v4 with: - fetch-depth: 0 - lfs: true + fetch-depth: 0 + lfs: true - name: CPU features run: sysctl -a | grep cpu - uses: Swatinem/rust-cache@v2 @@ -139,3 +139,102 @@ jobs: $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT cargo build cargo test + windows-arm64: + runs-on: windows-4x-arm + steps: + - name: Cache installations + id: cache-installs + uses: actions/cache@v4 + with: + path: | + C:\Program Files\Git + C:\BuildTools + C:\Program Files (x86)\Windows Kits + C:\Program Files\7-Zip + C:\protoc + key: ${{ runner.os }}-arm64-installs-v1 + restore-keys: | + ${{ runner.os }}-arm64-installs- + - name: Install Git + if: steps.cache-installs.outputs.cache-hit != 'true' + run: | + Invoke-WebRequest -Uri "https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/Git-2.44.0-64-bit.exe" -OutFile "git-installer.exe" + Start-Process -FilePath "git-installer.exe" -ArgumentList "/VERYSILENT", "/NORESTART" -Wait + shell: powershell + - name: Add Git to PATH + run: | + Add-Content $env:GITHUB_PATH "C:\Program Files\Git\bin" + $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + shell: powershell + - name: Configure Git symlinks + run: git config --global core.symlinks true + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.13" + - name: Install Visual Studio Build Tools + if: steps.cache-installs.outputs.cache-hit != 'true' + run: | + Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_buildtools.exe" -OutFile "vs_buildtools.exe" + Start-Process -FilePath "vs_buildtools.exe" -ArgumentList "--quiet", "--wait", "--norestart", "--nocache", ` + "--installPath", "C:\BuildTools", ` + "--add", "Microsoft.VisualStudio.Component.VC.Tools.ARM64", ` + "--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", ` + "--add", "Microsoft.VisualStudio.Component.Windows11SDK.22621", ` + "--add", "Microsoft.VisualStudio.Component.VC.ATL", ` + "--add", "Microsoft.VisualStudio.Component.VC.ATLMFC", ` + "--add", "Microsoft.VisualStudio.Component.VC.Llvm.Clang" -Wait + shell: powershell + - name: Add Visual Studio Build Tools to PATH + run: | + $vsPath = "C:\BuildTools\VC\Tools\MSVC" + $latestVersion = (Get-ChildItem $vsPath | Sort-Object {[version]$_.Name} -Descending)[0].Name + Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\arm64" + Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\MSVC\$latestVersion\bin\Hostx64\x64" + Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\arm64" + Add-Content $env:GITHUB_PATH "C:\Program Files (x86)\Windows Kits\10\bin\10.0.22621.0\x64" + Add-Content $env:GITHUB_PATH "C:\BuildTools\VC\Tools\Llvm\x64\bin" + + $env:LIB = "" + Add-Content $env:GITHUB_ENV "LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\arm64;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\arm64" + shell: powershell + - name: Install Rust + run: | + Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe + .\rustup-init.exe -y 
--default-host aarch64-pc-windows-msvc
+        shell: powershell
+      - name: Add Rust to PATH
+        run: |
+          Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin"
+        shell: powershell
+
+      - uses: Swatinem/rust-cache@v2
+        with:
+          workspaces: rust
+      - name: Install 7-Zip ARM
+        if: steps.cache-installs.outputs.cache-hit != 'true'
+        run: |
+          New-Item -Path 'C:\7zip' -ItemType Directory
+          Invoke-WebRequest https://7-zip.org/a/7z2408-arm64.exe -OutFile C:\7zip\7z-installer.exe
+          Start-Process -FilePath C:\7zip\7z-installer.exe -ArgumentList '/S' -Wait
+        shell: powershell
+      - name: Add 7-Zip to PATH
+        run: Add-Content $env:GITHUB_PATH "C:\Program Files\7-Zip"
+        shell: powershell
+      - name: Install Protoc v21.12
+        if: steps.cache-installs.outputs.cache-hit != 'true'
+        working-directory: C:\
+        run: |
+          New-Item -Path 'C:\protoc' -ItemType Directory
+          Set-Location C:\protoc
+          Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
+          & 'C:\Program Files\7-Zip\7z.exe' x protoc.zip
+        shell: powershell
+      - name: Add Protoc to PATH
+        run: Add-Content $env:GITHUB_PATH "C:\protoc\bin"
+        shell: powershell
+      - name: Run tests
+        run: |
+          $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
+          cargo build --target aarch64-pc-windows-msvc
+          cargo test --target aarch64-pc-windows-msvc
diff --git a/ci/build_windows_artifacts.ps1 b/ci/build_windows_artifacts.ps1
index 039d4b95..02f2d207 100644
--- a/ci/build_windows_artifacts.ps1
+++ b/ci/build_windows_artifacts.ps1
@@ -3,6 +3,7 @@
 # Targets supported:
 # - x86_64-pc-windows-msvc
 # - i686-pc-windows-msvc
+# - aarch64-pc-windows-msvc
 
 function Prebuild-Rust {
     param (
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
     $targets = $args[0]
 
     if (-not $targets) {
-        $targets = "x86_64-pc-windows-msvc"
+        $targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
     }
 
     Write-Host "Building artifacts for targets: $targets"
diff --git a/ci/build_windows_artifacts_nodejs.ps1 b/ci/build_windows_artifacts_nodejs.ps1
index b960f306..5c1ac4fa 100644
--- a/ci/build_windows_artifacts_nodejs.ps1
+++ b/ci/build_windows_artifacts_nodejs.ps1
@@ -3,6 +3,7 @@
 # Targets supported:
 # - x86_64-pc-windows-msvc
 # - i686-pc-windows-msvc
+# - aarch64-pc-windows-msvc
 
 function Prebuild-Rust {
     param (
@@ -31,7 +32,7 @@ function Build-NodeBinaries {
     $targets = $args[0]
 
     if (-not $targets) {
-        $targets = "x86_64-pc-windows-msvc"
+        $targets = "x86_64-pc-windows-msvc", "aarch64-pc-windows-msvc"
     }
 
     Write-Host "Building artifacts for targets: $targets"
diff --git a/java/core/pom.xml b/java/core/pom.xml
index 398bdc6f..557f92b3 100644
--- a/java/core/pom.xml
+++ b/java/core/pom.xml
@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.12.0-final.0</version>
+    <version>0.13.0-beta.1</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
diff --git a/java/pom.xml b/java/pom.xml
index acd4b9df..9c4790ce 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -6,7 +6,7 @@
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.12.0-final.0</version>
+  <version>0.13.0-beta.1</version>
   <packaging>pom</packaging>
   <name>LanceDB Parent</name>
diff --git a/node/package-lock.json b/node/package-lock.json
index c81e6662..e160fcf8 100644
--- a/node/package-lock.json
+++ b/node/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.12.0",
+  "version": "0.13.0-beta.1",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.12.0",
+      "version": "0.13.0-beta.1",
       "cpu": [
         "x64",
         "arm64"
@@ -52,11 +52,12 @@
       "uuid": "^9.0.0"
     },
     "optionalDependencies": {
-      "@lancedb/vectordb-darwin-arm64": "0.12.0",
-      "@lancedb/vectordb-darwin-x64": "0.12.0",
-
"@lancedb/vectordb-linux-arm64-gnu": "0.12.0", - "@lancedb/vectordb-linux-x64-gnu": "0.12.0", - "@lancedb/vectordb-win32-x64-msvc": "0.12.0" + "@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1", + "@lancedb/vectordb-darwin-x64": "0.13.0-beta.1", + "@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1", + "@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1", + "@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1", + "@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1" }, "peerDependencies": { "@apache-arrow/ts": "^14.0.2", @@ -326,6 +327,66 @@ "@jridgewell/sourcemap-codec": "^1.4.10" } }, + "node_modules/@lancedb/vectordb-darwin-arm64": { + "version": "0.13.0-beta.1", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.13.0-beta.1.tgz", + "integrity": "sha512-beOrf6selCzzhLgDG8Nibma4nO/CSnA1wUKRmlJHEPtGcg7PW18z6MP/nfwQMpMR/FLRfTo8pPTbpzss47MiQQ==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@lancedb/vectordb-darwin-x64": { + "version": "0.13.0-beta.1", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.13.0-beta.1.tgz", + "integrity": "sha512-YdraGRF/RbJRkKh0v3xT03LUhq47T2GtCvJ5gZp8wKlh4pHa8LuhLU0DIdvmG/DT5vuQA+td8HDkBm/e3EOdNg==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@lancedb/vectordb-linux-arm64-gnu": { + "version": "0.13.0-beta.1", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.13.0-beta.1.tgz", + "integrity": "sha512-Pp0O/uhEqof1oLaWrNbv+Ym+q8kBkiCqaA5+2eAZ6a3e9U+Ozkvb0FQrHuyi9adJ5wKQ4NabyQE9BMf2bYpOnQ==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@lancedb/vectordb-linux-x64-gnu": { + "version": "0.13.0-beta.1", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.13.0-beta.1.tgz", + "integrity": "sha512-y8nxOye4egfWF5FGED9EfkmZ1O5HnRLU4a61B8m5JSpkivO9v2epTcbYN0yt/7ZFCgtqMfJ8VW4Mi7qQcz3KDA==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@lancedb/vectordb-win32-x64-msvc": { + "version": "0.13.0-beta.1", + "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.13.0-beta.1.tgz", + "integrity": "sha512-STMDP9dp0TBLkB3ro+16pKcGy6bmbhRuEZZZ1Tp5P75yTPeVh4zIgWkidMdU1qBbEYM7xacnsp9QAwgLnMU/Ow==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "win32" + ] + }, "node_modules/@neon-rs/cli": { "version": "0.0.160", "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz", diff --git a/node/package.json b/node/package.json index f293b411..e3045277 100644 --- a/node/package.json +++ b/node/package.json @@ -1,6 +1,6 @@ { "name": "vectordb", - "version": "0.12.0", + "version": "0.13.0-beta.1", "description": " Serverless, low-latency vector database for AI applications", "main": "dist/index.js", "types": "dist/index.d.ts", @@ -84,14 +84,16 @@ "aarch64-apple-darwin": "@lancedb/vectordb-darwin-arm64", "x86_64-unknown-linux-gnu": "@lancedb/vectordb-linux-x64-gnu", "aarch64-unknown-linux-gnu": "@lancedb/vectordb-linux-arm64-gnu", - "x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc" + "x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc", + "aarch64-pc-windows-msvc": "@lancedb/vectordb-win32-arm64-msvc" } }, "optionalDependencies": { - "@lancedb/vectordb-darwin-arm64": "0.12.0", - "@lancedb/vectordb-darwin-x64": "0.12.0", - 
"@lancedb/vectordb-linux-arm64-gnu": "0.12.0", - "@lancedb/vectordb-linux-x64-gnu": "0.12.0", - "@lancedb/vectordb-win32-x64-msvc": "0.12.0" + "@lancedb/vectordb-darwin-arm64": "0.13.0-beta.1", + "@lancedb/vectordb-darwin-x64": "0.13.0-beta.1", + "@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.1", + "@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.1", + "@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.1", + "@lancedb/vectordb-win32-arm64-msvc": "0.13.0-beta.1" } } diff --git a/node/src/remote/client.ts b/node/src/remote/client.ts index 01a11f07..cf99a182 100644 --- a/node/src/remote/client.ts +++ b/node/src/remote/client.ts @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -import axios, { type AxiosResponse, type ResponseType } from 'axios' +import axios, { type AxiosError, type AxiosResponse, type ResponseType } from 'axios' import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow' @@ -197,7 +197,7 @@ export class HttpLancedbClient { response = await callWithMiddlewares(req, this._middlewares) return response } catch (err: any) { - console.error('error: ', err) + console.error(serializeErrorAsJson(err)) if (err.response === undefined) { throw new Error(`Network Error: ${err.message as string}`) } @@ -247,7 +247,8 @@ export class HttpLancedbClient { // return response } catch (err: any) { - console.error('error: ', err) + console.error(serializeErrorAsJson(err)) + if (err.response === undefined) { throw new Error(`Network Error: ${err.message as string}`) } @@ -287,3 +288,15 @@ export class HttpLancedbClient { return clone } } + +function serializeErrorAsJson(err: AxiosError) { + const error = JSON.parse(JSON.stringify(err, Object.getOwnPropertyNames(err))) + error.response = err.response != null + ? 
JSON.parse(JSON.stringify( + err.response, + // config contains the request data, too noisy + Object.getOwnPropertyNames(err.response).filter(prop => prop !== 'config') + )) + : null + return JSON.stringify({ error }) +} diff --git a/nodejs/Cargo.toml b/nodejs/Cargo.toml index 81733d0e..ba7af8da 100644 --- a/nodejs/Cargo.toml +++ b/nodejs/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "lancedb-nodejs" edition.workspace = true -version = "0.12.0" +version = "0.13.0-beta.1" license.workspace = true description.workspace = true repository.workspace = true diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 5bf01dad..33d01858 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -402,6 +402,40 @@ describe("When creating an index", () => { expect(rst.numRows).toBe(1); }); + it("should be able to query unindexed data", async () => { + await tbl.createIndex("vec"); + await tbl.add([ + { + id: 300, + vec: Array(32) + .fill(1) + .map(() => Math.random()), + tags: [], + }, + ]); + + const plan1 = await tbl.query().nearestTo(queryVec).explainPlan(true); + expect(plan1).toMatch("LanceScan"); + + const plan2 = await tbl + .query() + .nearestTo(queryVec) + .fastSearch() + .explainPlan(true); + expect(plan2).not.toMatch("LanceScan"); + }); + + it("should be able to query with row id", async () => { + const results = await tbl + .query() + .nearestTo(queryVec) + .withRowId() + .limit(1) + .toArray(); + expect(results.length).toBe(1); + expect(results[0]).toHaveProperty("_rowid"); + }); + it("should allow parameters to be specified", async () => { await tbl.createIndex("vec", { config: Index.ivfPq({ diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 8c0f51cf..32d58e05 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -239,6 +239,29 @@ export class QueryBase return this; } + /** + * Skip searching un-indexed data. This can make search faster, but will miss + * any data that is not yet indexed. + * + * Use {@link lancedb.Table#optimize} to index all un-indexed data. + */ + fastSearch(): this { + this.doCall((inner: NativeQueryType) => inner.fastSearch()); + return this; + } + + /** + * Whether to return the row id in the results. + * + * This column can be used to match results between different queries. For + * example, to match results from a full text search and a vector search in + * order to perform hybrid search. 
+ */ + withRowId(): this { + this.doCall((inner: NativeQueryType) => inner.withRowId()); + return this; + } + protected nativeExecute( options?: Partial, ): Promise { diff --git a/nodejs/npm/darwin-arm64/package.json b/nodejs/npm/darwin-arm64/package.json index b9938915..41b27452 100644 --- a/nodejs/npm/darwin-arm64/package.json +++ b/nodejs/npm/darwin-arm64/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-darwin-arm64", - "version": "0.12.0", + "version": "0.13.0-beta.1", "os": ["darwin"], "cpu": ["arm64"], "main": "lancedb.darwin-arm64.node", diff --git a/nodejs/npm/darwin-x64/package.json b/nodejs/npm/darwin-x64/package.json index 8b3da0f4..83ab1882 100644 --- a/nodejs/npm/darwin-x64/package.json +++ b/nodejs/npm/darwin-x64/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-darwin-x64", - "version": "0.12.0", + "version": "0.13.0-beta.1", "os": ["darwin"], "cpu": ["x64"], "main": "lancedb.darwin-x64.node", diff --git a/nodejs/npm/linux-arm64-gnu/package.json b/nodejs/npm/linux-arm64-gnu/package.json index 55e3c7f2..74d73114 100644 --- a/nodejs/npm/linux-arm64-gnu/package.json +++ b/nodejs/npm/linux-arm64-gnu/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-arm64-gnu", - "version": "0.12.0", + "version": "0.13.0-beta.1", "os": ["linux"], "cpu": ["arm64"], "main": "lancedb.linux-arm64-gnu.node", diff --git a/nodejs/npm/linux-x64-gnu/package.json b/nodejs/npm/linux-x64-gnu/package.json index 37219174..984ad06b 100644 --- a/nodejs/npm/linux-x64-gnu/package.json +++ b/nodejs/npm/linux-x64-gnu/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-x64-gnu", - "version": "0.12.0", + "version": "0.13.0-beta.1", "os": ["linux"], "cpu": ["x64"], "main": "lancedb.linux-x64-gnu.node", diff --git a/nodejs/npm/win32-arm64-msvc/README.md b/nodejs/npm/win32-arm64-msvc/README.md new file mode 100644 index 00000000..939180c6 --- /dev/null +++ b/nodejs/npm/win32-arm64-msvc/README.md @@ -0,0 +1,3 @@ +# `@lancedb/lancedb-win32-arm64-msvc` + +This is the **aarch64-pc-windows-msvc** binary for `@lancedb/lancedb` diff --git a/nodejs/npm/win32-arm64-msvc/package.json b/nodejs/npm/win32-arm64-msvc/package.json new file mode 100644 index 00000000..0478cef7 --- /dev/null +++ b/nodejs/npm/win32-arm64-msvc/package.json @@ -0,0 +1,18 @@ +{ + "name": "@lancedb/lancedb-win32-arm64-msvc", + "version": "0.12.0", + "os": [ + "win32" + ], + "cpu": [ + "arm64" + ], + "main": "lancedb.win32-arm64-msvc.node", + "files": [ + "lancedb.win32-arm64-msvc.node" + ], + "license": "Apache 2.0", + "engines": { + "node": ">= 18" + } +} diff --git a/nodejs/npm/win32-x64-msvc/package.json b/nodejs/npm/win32-x64-msvc/package.json index 4c705e06..33c97c35 100644 --- a/nodejs/npm/win32-x64-msvc/package.json +++ b/nodejs/npm/win32-x64-msvc/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-win32-x64-msvc", - "version": "0.12.0", + "version": "0.13.0-beta.1", "os": ["win32"], "cpu": ["x64"], "main": "lancedb.win32-x64-msvc.node", diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index a642d63b..f56ad672 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -1,12 +1,12 @@ { "name": "@lancedb/lancedb", - "version": "0.11.1-beta.1", + "version": "0.12.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@lancedb/lancedb", - "version": "0.11.1-beta.1", + "version": "0.12.0", "cpu": [ "x64", "arm64" diff --git a/nodejs/package.json b/nodejs/package.json index fbd76092..943d30fa 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -10,7 +10,7 @@ 
"vector database", "ann" ], - "version": "0.12.0", + "version": "0.13.0-beta.1", "main": "dist/index.js", "exports": { ".": "./dist/index.js", diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index a18bc75d..d0a02ee4 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -82,7 +82,7 @@ pub struct OpenTableOptions { #[napi::module_init] fn init() { let env = Env::new() - .filter_or("LANCEDB_LOG", "trace") + .filter_or("LANCEDB_LOG", "warn") .write_style("LANCEDB_LOG_STYLE"); env_logger::init_from_env(env); } diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index d0132699..448ca134 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -80,6 +80,16 @@ impl Query { Ok(VectorQuery { inner }) } + #[napi] + pub fn fast_search(&mut self) { + self.inner = self.inner.clone().fast_search(); + } + + #[napi] + pub fn with_row_id(&mut self) { + self.inner = self.inner.clone().with_row_id(); + } + #[napi(catch_unwind)] pub async fn execute( &self, @@ -183,6 +193,16 @@ impl VectorQuery { self.inner = self.inner.clone().offset(offset as usize); } + #[napi] + pub fn fast_search(&mut self) { + self.inner = self.inner.clone().fast_search(); + } + + #[napi] + pub fn with_row_id(&mut self) { + self.inner = self.inner.clone().with_row_id(); + } + #[napi(catch_unwind)] pub async fn execute( &self, diff --git a/python/.bumpversion.toml b/python/.bumpversion.toml index 4b25c2f0..f02351a1 100644 --- a/python/.bumpversion.toml +++ b/python/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "0.15.0" +current_version = "0.16.0-beta.0" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/python/Cargo.toml b/python/Cargo.toml index 31c825cb..74a9e7fe 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lancedb-python" -version = "0.15.0" +version = "0.16.0-beta.0" edition.workspace = true description = "Python bindings for LanceDB" license.workspace = true diff --git a/python/pyproject.toml b/python/pyproject.toml index 10dc7375..86be43a1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,13 +3,11 @@ name = "lancedb" # version in Cargo.toml dependencies = [ "deprecation", - "pylance==0.19.1", - "requests>=2.31.0", + "nest-asyncio~=1.0", + "pylance==0.19.2-beta.3", "tqdm>=4.27.0", "pydantic>=1.10", - "attrs>=21.3.0", "packaging", - "cachetools", "overrides>=0.7", ] description = "lancedb" @@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] embeddings = [ + "requests>=2.31.0", "openai>=1.6.1", "sentence-transformers", "torch", diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index b394fa6f..2c5e521d 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any __version__ = importlib.metadata.version("lancedb") -from lancedb.remote import ClientConfig - from ._lancedb import connect as lancedb_connect from .common import URI, sanitize_uri from .db import AsyncConnection, DBConnection, LanceDBConnection -from .remote.db import RemoteDBConnection +from .remote import ClientConfig from .schema import vector from .table import AsyncTable @@ -37,6 +35,7 @@ def connect( host_override: Optional[str] = None, read_consistency_interval: Optional[timedelta] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, + client_config: Union[ClientConfig, Dict[str, Any], None] = 
None, **kwargs: Any, ) -> DBConnection: """Connect to a LanceDB database. @@ -64,14 +63,10 @@ def connect( the last check, then the table will be checked for updates. Note: this consistency only applies to read operations. Write operations are always consistent. - request_thread_pool: int or ThreadPoolExecutor, optional - The thread pool to use for making batch requests to the LanceDB Cloud API. - If an integer, then a ThreadPoolExecutor will be created with that - number of threads. If None, then a ThreadPoolExecutor will be created - with the default number of threads. If a ThreadPoolExecutor, then that - executor will be used for making requests. This is for LanceDB Cloud - only and is only used when making batch requests (i.e., passing in - multiple queries to the search method at once). + client_config: ClientConfig or dict, optional + Configuration options for the LanceDB Cloud HTTP client. If a dict, then + the keys are the attributes of the ClientConfig class. If None, then the + default configuration is used. Examples -------- @@ -94,6 +89,8 @@ def connect( conn : DBConnection A connection to a LanceDB database. """ + from .remote.db import RemoteDBConnection + if isinstance(uri, str) and uri.startswith("db://"): if api_key is None: api_key = os.environ.get("LANCEDB_API_KEY") @@ -106,7 +103,9 @@ def connect( api_key, region, host_override, + # TODO: remove this (deprecation warning downstream) request_thread_pool=request_thread_pool, + client_config=client_config, **kwargs, ) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 33b1f07c..bc4d6617 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -36,6 +36,8 @@ class Connection(object): data_storage_version: Optional[str] = None, enable_v2_manifest_paths: Optional[bool] = None, ) -> Table: ... + async def rename_table(self, old_name: str, new_name: str) -> None: ... + async def drop_table(self, name: str) -> None: ... class Table: def name(self) -> str: ... diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 6af4cdb8..0a9e27d8 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -817,6 +817,18 @@ class AsyncConnection(object): table = await self._inner.open_table(name, storage_options, index_cache_size) return AsyncTable(table) + async def rename_table(self, old_name: str, new_name: str): + """Rename a table in the database. + + Parameters + ---------- + old_name: str + The current name of the table. + new_name: str + The new name of the table. + """ + await self._inner.rename_table(old_name, new_name) + async def drop_table(self, name: str): """Drop a table from the database. 
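The `rename_table` coroutine added above mirrors the existing `drop_table`, and both are now stubbed in `_lancedb.pyi`. A minimal usage sketch (not part of the diff; the URI, key, and table names are placeholders, and renaming assumes a backend that supports it, such as LanceDB Cloud):

```python
import asyncio

import lancedb


async def main():
    # connect_async returns an AsyncConnection, which now exposes rename_table
    db = await lancedb.connect_async(
        "db://my-project", api_key="...", region="us-east-1"
    )
    await db.rename_table("events_staging", "events")
    await db.drop_table("events_old")


asyncio.run(main())
```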
diff --git a/python/python/lancedb/embeddings/jinaai.py b/python/python/lancedb/embeddings/jinaai.py index 6619627d..5f89d97c 100644 --- a/python/python/lancedb/embeddings/jinaai.py +++ b/python/python/lancedb/embeddings/jinaai.py @@ -13,7 +13,6 @@ import os import io -import requests import base64 from urllib.parse import urlparse from pathlib import Path @@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction): return [result["embedding"] for result in sorted_embeddings] def _init_client(self): + import requests + if JinaEmbeddings._session is None: if self.api_key is None and os.environ.get("JINA_API_KEY") is None: api_key_not_found_help("jina") diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py index b7e44b52..a1b06a29 100644 --- a/python/python/lancedb/index.py +++ b/python/python/lancedb/index.py @@ -467,6 +467,8 @@ class IvfPq: The default value is 256. """ + if distance_type is not None: + distance_type = distance_type.lower() self._inner = LanceDbIndex.ivf_pq( distance_type=distance_type, num_partitions=num_partitions, diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 1062289e..09eaa414 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -481,6 +481,7 @@ class LanceQueryBuilder(ABC): >>> plan = table.search(query).explain_plan(True) >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance] + GlobalLimitExec: skip=0, fetch=10 FilterExec: _distance@2 IS NOT NULL SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] KNNVectorDistance: metric=l2 @@ -500,7 +501,16 @@ class LanceQueryBuilder(ABC): nearest={ "column": self._vector_column, "q": self._query, + "k": self._limit, + "metric": self._metric, + "nprobes": self._nprobes, + "refine_factor": self._refine_factor, }, + prefilter=self._prefilter, + filter=self._str_query, + limit=self._limit, + with_row_id=self._with_row_id, + offset=self._offset, ).explain_plan(verbose) def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder: @@ -1315,6 +1325,48 @@ class AsyncQueryBase(object): self._inner.offset(offset) return self + def fast_search(self) -> AsyncQuery: + """ + Skip searching un-indexed data. + + This can make queries faster, but will miss any data that has not been + indexed. + + !!! tip + You can add new data into an existing index by calling + [AsyncTable.optimize][lancedb.table.AsyncTable.optimize]. + """ + self._inner.fast_search() + return self + + def with_row_id(self) -> AsyncQuery: + """ + Include the _rowid column in the results. + """ + self._inner.with_row_id() + return self + + def postfilter(self) -> AsyncQuery: + """ + If this is called then filtering will happen after the search instead of + before. + By default filtering will be performed before the search. This is how + filtering is typically understood to work. This prefilter step does add some + additional latency. Creating a scalar index on the filter column(s) can + often improve this latency. However, sometimes a filter is too complex or + scalar indices cannot be applied to the column. In these cases postfiltering + can be used instead of prefiltering to improve latency. + Post filtering applies the filter to the results of the search. This + means we only run the filter on a much smaller set of data. However, it can + cause the query to return fewer than `limit` results (or even no results) if + none of the nearest results match the filter. 
+ Post filtering happens during the "refine stage" (described in more detail in + @see {@link VectorQuery#refineFactor}). This means that setting a higher refine + factor can often help restore some of the results lost by post filtering. + """ + self._inner.postfilter() + return self + async def to_batches( self, *, max_batch_length: Optional[int] = None ) -> AsyncRecordBatchReader: @@ -1618,30 +1670,6 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.distance_type(distance_type) return self - def postfilter(self) -> AsyncVectorQuery: - """ - If this is called then filtering will happen after the vector search instead of - before. - - By default filtering will be performed before the vector search. This is how - filtering is typically understood to work. This prefilter step does add some - additional latency. Creating a scalar index on the filter column(s) can - often improve this latency. However, sometimes a filter is too complex or - scalar indices cannot be applied to the column. In these cases postfiltering - can be used instead of prefiltering to improve latency. - - Post filtering applies the filter to the results of the vector search. This - means we only run the filter on a much smaller set of data. However, it can - cause the query to return fewer than `limit` results (or even no results) if - none of the nearest results match the filter. - - Post filtering happens during the "refine stage" (described in more detail in - @see {@link VectorQuery#refineFactor}). This means that setting a higher refine - factor can often help restore some of the results lost by post filtering. - """ - self._inner.postfilter() - return self - def bypass_vector_index(self) -> AsyncVectorQuery: """ If this is called then any vector index is skipped diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 98cbd2e5..e834c226 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -11,62 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import timedelta from typing import List, Optional -import attrs from lancedb import __version__ -import pyarrow as pa -from pydantic import BaseModel -from lancedb.common import VECTOR_COLUMN_NAME - -__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"] - - -class VectorQuery(BaseModel): - # vector to search for - vector: List[float] - - # sql filter to refine the query with - filter: Optional[str] = None - - # top k results to return - k: int - - # # metrics - _metric: str = "L2" - - # which columns to return in the results - columns: Optional[List[str]] = None - - # optional query parameters for tuning the results, - # e.g. 
`{"nprobes": "10", "refine_factor": "10"}` - nprobes: int = 10 - - refine_factor: Optional[int] = None - - vector_column: str = VECTOR_COLUMN_NAME - - fast_search: bool = False - - -@attrs.define -class VectorQueryResult: - # for now the response is directly seralized into a pandas dataframe - tbl: pa.Table - - def to_arrow(self) -> pa.Table: - return self.tbl - - -class LanceDBClient(abc.ABC): - @abc.abstractmethod - def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: - """Query the LanceDB server for the given table and query.""" - pass +__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"] @dataclass @@ -165,8 +116,8 @@ class RetryConfig: @dataclass class ClientConfig: user_agent: str = f"LanceDB-Python-Client/{__version__}" - retry_config: Optional[RetryConfig] = None - timeout_config: Optional[TimeoutConfig] = None + retry_config: RetryConfig = field(default_factory=RetryConfig) + timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig) def __post_init__(self): if isinstance(self.retry_config, dict): diff --git a/python/python/lancedb/remote/arrow.py b/python/python/lancedb/remote/arrow.py deleted file mode 100644 index ac39e247..00000000 --- a/python/python/lancedb/remote/arrow.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Iterable, Union -import pyarrow as pa - - -def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes: - """Serialize a PyArrow Table to IPC binary.""" - sink = pa.BufferOutputStream() - if isinstance(table, Iterable): - table = pa.Table.from_batches(table) - with pa.ipc.new_stream(sink, table.schema) as writer: - writer.write_table(table) - return sink.getvalue().to_pybytes() diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py deleted file mode 100644 index d546e92f..00000000 --- a/python/python/lancedb/remote/client.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import functools -import logging -import os -from typing import Any, Callable, Dict, List, Optional, Union -from urllib.parse import urljoin - -import attrs -import pyarrow as pa -import requests -from pydantic import BaseModel -from requests.adapters import HTTPAdapter -from urllib3 import Retry - -from lancedb.common import Credential -from lancedb.remote import VectorQuery, VectorQueryResult -from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory -from lancedb.remote.errors import LanceDBClientError - -ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream" - - -def _check_not_closed(f): - @functools.wraps(f) - def wrapped(self, *args, **kwargs): - if self.closed: - raise ValueError("Connection is closed") - return f(self, *args, **kwargs) - - return wrapped - - -def _read_ipc(resp: requests.Response) -> pa.Table: - resp_body = resp.content - with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader: - return reader.read_all() - - -@attrs.define(slots=False) -class RestfulLanceDBClient: - db_name: str - region: str - api_key: Credential - host_override: Optional[str] = attrs.field(default=None) - - closed: bool = attrs.field(default=False, init=False) - - connection_timeout: float = attrs.field(default=120.0, kw_only=True) - read_timeout: float = attrs.field(default=300.0, kw_only=True) - - @functools.cached_property - def session(self) -> requests.Session: - sess = requests.Session() - - retry_adapter_instance = retry_adapter(retry_adapter_options()) - sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance) - - adapter_class = LanceDBClientHTTPAdapterFactory() - sess.mount("https://", adapter_class()) - return sess - - @property - def url(self) -> str: - return ( - self.host_override - or f"https://{self.db_name}.{self.region}.api.lancedb.com" - ) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - return False # Do not suppress exceptions - - def close(self): - self.session.close() - self.closed = True - - @functools.cached_property - def headers(self) -> Dict[str, str]: - headers = { - "x-api-key": self.api_key, - } - if self.region == "local": # Local test mode - headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com" - if self.host_override: - headers["x-lancedb-database"] = self.db_name - return headers - - @staticmethod - def _check_status(resp: requests.Response): - # Leaving request id empty for now, as we'll be replacing this impl - # with the Rust one shortly. 
- if resp.status_code == 404: - raise LanceDBClientError( - f"Not found: {resp.text}", request_id="", status_code=404 - ) - elif 400 <= resp.status_code < 500: - raise LanceDBClientError( - f"Bad Request: {resp.status_code}, error: {resp.text}", - request_id="", - status_code=resp.status_code, - ) - elif 500 <= resp.status_code < 600: - raise LanceDBClientError( - f"Internal Server Error: {resp.status_code}, error: {resp.text}", - request_id="", - status_code=resp.status_code, - ) - elif resp.status_code != 200: - raise LanceDBClientError( - f"Unknown Error: {resp.status_code}, error: {resp.text}", - request_id="", - status_code=resp.status_code, - ) - - @_check_not_closed - def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None): - """Send a GET request and returns the deserialized response payload.""" - if isinstance(params, BaseModel): - params: Dict[str, Any] = params.dict(exclude_none=True) - with self.session.get( - urljoin(self.url, uri), - params=params, - headers=self.headers, - timeout=(self.connection_timeout, self.read_timeout), - ) as resp: - self._check_status(resp) - return resp.json() - - @_check_not_closed - def post( - self, - uri: str, - data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None, - params: Optional[Dict[str, Any]] = None, - content_type: Optional[str] = None, - deserialize: Callable = lambda resp: resp.json(), - request_id: Optional[str] = None, - ) -> Dict[str, Any]: - """Send a POST request and returns the deserialized response payload. - - Parameters - ---------- - uri : str - The uri to send the POST request to. - data: Union[Dict[str, Any], BaseModel] - request_id: Optional[str] - Optional client side request id to be sent in the request headers. - - """ - if isinstance(data, BaseModel): - data: Dict[str, Any] = data.dict(exclude_none=True) - if isinstance(data, bytes): - req_kwargs = {"data": data} - else: - req_kwargs = {"json": data} - - headers = self.headers.copy() - if content_type is not None: - headers["content-type"] = content_type - if request_id is not None: - headers["x-request-id"] = request_id - with self.session.post( - urljoin(self.url, uri), - headers=headers, - params=params, - timeout=(self.connection_timeout, self.read_timeout), - **req_kwargs, - ) as resp: - self._check_status(resp) - return deserialize(resp) - - @_check_not_closed - def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]: - """List all tables in the database.""" - if page_token is None: - page_token = "" - json = self.get("/v1/table/", {"limit": limit, "page_token": page_token}) - return json["tables"] - - @_check_not_closed - def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: - """Query a table.""" - tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc) - return VectorQueryResult(tbl) - - def mount_retry_adapter_for_table(self, table_name: str) -> None: - """ - Adds an http adapter to session that will retry retryable requests to the table. 
- """ - retry_options = retry_adapter_options(methods=["GET", "POST"]) - retry_adapter_instance = retry_adapter(retry_options) - session = self.session - - session.mount( - urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance - ) - session.mount( - urljoin(self.url, f"/v1/table/{table_name}/describe/"), - retry_adapter_instance, - ) - session.mount( - urljoin(self.url, f"/v1/table/{table_name}/index/list/"), - retry_adapter_instance, - ) - - -def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]: - return { - "retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")), - "connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")), - "read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")), - "backoff_factor": float( - os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25") - ), - "backoff_jitter": float( - os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25") - ), - "statuses": [ - int(i.strip()) - for i in os.environ.get( - "LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503" - ).split(",") - ], - "methods": methods, - } - - -def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter: - total_retries = options["retries"] - connect_retries = options["connect_retries"] - read_retries = options["read_retries"] - backoff_factor = options["backoff_factor"] - backoff_jitter = options["backoff_jitter"] - statuses = options["statuses"] - methods = frozenset(options["methods"]) - logging.debug( - f"Setting up retry adapter with {total_retries} retries," # noqa G003 - + f"connect retries {connect_retries}, read retries {read_retries}," - + f"backoff factor {backoff_factor}, statuses {statuses}, " - + f"methods {methods}" - ) - - return HTTPAdapter( - max_retries=Retry( - total=total_retries, - connect=connect_retries, - read=read_retries, - backoff_factor=backoff_factor, - backoff_jitter=backoff_jitter, - status_forcelist=statuses, - allowed_methods=methods, - ) - ) diff --git a/python/python/lancedb/remote/connection_timeout.py b/python/python/lancedb/remote/connection_timeout.py deleted file mode 100644 index f9d18e56..00000000 --- a/python/python/lancedb/remote/connection_timeout.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2024 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This module contains an adapter that will close connections if they have not been -# used before a certain timeout. This is necessary because some load balancers will -# close connections after a certain amount of time, but the request module may not yet -# have received the FIN/ACK and will try to reuse the connection. 
-# -# TODO some of the code here can be simplified if/when this PR is merged: -# https://github.com/urllib3/urllib3/pull/3275 - -import datetime -import logging -import os - -from requests.adapters import HTTPAdapter -from urllib3.connection import HTTPSConnection -from urllib3.connectionpool import HTTPSConnectionPool -from urllib3.poolmanager import PoolManager - - -def get_client_connection_timeout() -> int: - return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300")) - - -class LanceDBHTTPSConnection(HTTPSConnection): - """ - HTTPSConnection that tracks the last time it was used. - """ - - idle_timeout: datetime.timedelta - last_activity: datetime.datetime - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.last_activity = datetime.datetime.now() - - def request(self, *args, **kwargs): - self.last_activity = datetime.datetime.now() - super().request(*args, **kwargs) - - def is_expired(self): - return datetime.datetime.now() - self.last_activity > self.idle_timeout - - -def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int): - """ - Creates a connection pool class that can be used to close idle connections. - """ - - class LanceDBHTTPSConnectionPool(HTTPSConnectionPool): - # override the connection class - ConnectionCls = LanceDBHTTPSConnection - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _get_conn(self, timeout: float | None = None): - logging.debug("Getting https connection") - conn = super()._get_conn(timeout) - if conn.is_expired(): - logging.debug("Closing expired connection") - conn.close() - - return conn - - def _new_conn(self): - conn = super()._new_conn() - conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout) - return conn - - return LanceDBHTTPSConnectionPool - - -class LanceDBClientPoolManager(PoolManager): - def __init__( - self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs - ): - super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs) - # inject our connection pool impl - connection_pool_class = LanceDBHTTPSConnectionPoolFactory( - client_idle_timeout=client_idle_timeout - ) - self.pool_classes_by_scheme["https"] = connection_pool_class - - -def LanceDBClientHTTPAdapterFactory(): - """ - Creates an HTTPAdapter class that can be used to close idle connections - """ - - # closure over the timeout - client_idle_timeout = get_client_connection_timeout() - - class LanceDBClientRequestHTTPAdapter(HTTPAdapter): - def init_poolmanager(self, connections, maxsize, block=False): - # inject our pool manager impl - self.poolmanager = LanceDBClientPoolManager( - client_idle_timeout=client_idle_timeout, - num_pools=connections, - maxsize=maxsize, - block=block, - ) - - return LanceDBClientRequestHTTPAdapter diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index bb7554a4..51ef389e 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -11,13 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
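The rewritten `RemoteDBConnection` below translates the legacy `connection_timeout` and `read_timeout` keyword arguments into the structured `ClientConfig`. A sketch of the explicit form the deprecation shims steer callers toward (placeholder URI and key; the field names come from the `ClientConfig` and `TimeoutConfig` dataclasses this PR wires in):

```python
from datetime import timedelta

import lancedb
from lancedb.remote import ClientConfig, TimeoutConfig

db = lancedb.connect(
    "db://my-project",
    api_key="...",
    region="us-east-1",
    client_config=ClientConfig(
        timeout_config=TimeoutConfig(
            connect_timeout=timedelta(seconds=120),  # was connection_timeout=120.0
            read_timeout=timedelta(seconds=300),  # was read_timeout=300.0
        )
    ),
)
```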
+import asyncio +from datetime import timedelta import logging -import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union from urllib.parse import urlparse +import warnings -from cachetools import TTLCache +from lancedb import connect_async +from lancedb.remote import ClientConfig import pyarrow as pa from overrides import override @@ -25,10 +28,8 @@ from ..common import DATA from ..db import DBConnection from ..embeddings import EmbeddingFunctionConfig from ..pydantic import LanceModel -from ..table import Table, sanitize_create_table +from ..table import Table from ..util import validate_table_name -from .arrow import to_ipc_binary -from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient class RemoteDBConnection(DBConnection): @@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection): region: str, host_override: Optional[str] = None, request_thread_pool: Optional[ThreadPoolExecutor] = None, - connection_timeout: float = 120.0, - read_timeout: float = 300.0, + client_config: Union[ClientConfig, Dict[str, Any], None] = None, + connection_timeout: Optional[float] = None, + read_timeout: Optional[float] = None, ): """Connect to a remote LanceDB database.""" + + if isinstance(client_config, dict): + client_config = ClientConfig(**client_config) + elif client_config is None: + client_config = ClientConfig() + + # These are legacy options from the old Python-based client. We keep them + # here for backwards compatibility, but will remove them in a future release. + if request_thread_pool is not None: + warnings.warn( + "request_thread_pool is no longer used and will be removed in " + "a future release.", + DeprecationWarning, + ) + + if connection_timeout is not None: + warnings.warn( + "connection_timeout is deprecated and will be removed in a future " + "release. Please use client_config.timeout_config.connect_timeout " + "instead.", + DeprecationWarning, + ) + client_config.timeout_config.connect_timeout = timedelta( + seconds=connection_timeout + ) + + if read_timeout is not None: + warnings.warn( + "read_timeout is deprecated and will be removed in a future release. " + "Please use client_config.timeout_config.read_timeout instead.", + DeprecationWarning, + ) + client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout) + parsed = urlparse(db_url) if parsed.scheme != "db": raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") - self._uri = str(db_url) self.db_name = parsed.netloc - self.api_key = api_key - self._client = RestfulLanceDBClient( - self.db_name, - region, - api_key, - host_override, - connection_timeout=connection_timeout, - read_timeout=read_timeout, + + import nest_asyncio + + nest_asyncio.apply() + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self.client_config = client_config + + self._conn = self._loop.run_until_complete( + connect_async( + db_url, + api_key=api_key, + region=region, + host_override=host_override, + client_config=client_config, + ) ) - self._request_thread_pool = request_thread_pool - self._table_cache = TTLCache(maxsize=10000, ttl=300) def __repr__(self) -> str: return f"RemoteConnect(name={self.db_name})" @@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection): ------- An iterator of table names. 
""" - while True: - result = self._client.list_tables(limit, page_token) - - if len(result) > 0: - page_token = result[len(result) - 1] - else: - break - for item in result: - self._table_cache[item] = True - yield item + return self._loop.run_until_complete( + self._conn.table_names(start_after=page_token, limit=limit) + ) @override def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table: @@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection): """ from .table import RemoteTable - self._client.mount_retry_adapter_for_table(name) - if index_cache_size is not None: logging.info( "index_cache_size is ignored in LanceDb Cloud" " (there is no local cache to configure)" ) - # check if table exists - if self._table_cache.get(name) is None: - self._client.post(f"/v1/table/{name}/describe/") - self._table_cache[name] = True - - return RemoteTable(self, name) + table = self._loop.run_until_complete(self._conn.open_table(name)) + return RemoteTable(table, self.db_name, self._loop) @override def create_table( @@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection): "Please vote https://github.com/lancedb/lancedb/issues/626 " "for this feature." ) - if mode is not None: - logging.warning("mode is not yet supported on LanceDB Cloud.") - - data, schema = sanitize_create_table( - data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value - ) from .table import RemoteTable - data = to_ipc_binary(data) - request_id = uuid.uuid4().hex - - self._client.post( - f"/v1/table/{name}/create/", - data=data, - request_id=request_id, - content_type=ARROW_STREAM_CONTENT_TYPE, + table = self._loop.run_until_complete( + self._conn.create_table( + name, + data, + mode=mode, + schema=schema, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) ) - - self._table_cache[name] = True - return RemoteTable(self, name) + return RemoteTable(table, self.db_name, self._loop) @override def drop_table(self, name: str): @@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection): name: str The name of the table. """ - - self._client.post( - f"/v1/table/{name}/drop/", - ) - self._table_cache.pop(name, default=None) + self._loop.run_until_complete(self._conn.drop_table(name)) @override def rename_table(self, cur_name: str, new_name: str): @@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection): new_name: str The new name of the table. """ - self._client.post( - f"/v1/table/{cur_name}/rename/", - data={"new_table_name": new_name}, - ) - self._table_cache.pop(cur_name, default=None) - self._table_cache[new_name] = True + self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name)) async def close(self): """Close the connection to the database.""" diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 986fbced..e2d88b98 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -11,53 +11,56 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import logging -import uuid -from concurrent.futures import Future from functools import cached_property from typing import Dict, Iterable, List, Optional, Union, Literal +from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList import pyarrow as pa -from lance import json_to_schema from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.merge import LanceMergeInsertBuilder from lancedb.embeddings import EmbeddingFunctionRegistry from ..query import LanceVectorQueryBuilder, LanceQueryBuilder -from ..table import Query, Table, _sanitize_data -from ..util import value_to_sql, infer_vector_column_name -from .arrow import to_ipc_binary -from .client import ARROW_STREAM_CONTENT_TYPE -from .db import RemoteDBConnection +from ..table import AsyncTable, Query, Table class RemoteTable(Table): - def __init__(self, conn: RemoteDBConnection, name: str): - self._conn = conn - self.name = name + def __init__( + self, + table: AsyncTable, + db_name: str, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + self._loop = loop + self._table = table + self.db_name = db_name + + @property + def name(self) -> str: + """The name of the table""" + return self._table.name def __repr__(self) -> str: - return f"RemoteTable({self._conn.db_name}.{self.name})" + return f"RemoteTable({self.db_name}.{self.name})" def __len__(self) -> int: self.count_rows(None) - @cached_property + @property def schema(self) -> pa.Schema: """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of this Table """ - resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") - schema = json_to_schema(resp["schema"]) - return schema + return self._loop.run_until_complete(self._table.schema()) @property def version(self) -> int: """Get the current version of the table""" - resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") - return resp["version"] + return self._loop.run_until_complete(self._table.version()) @cached_property def embedding_functions(self) -> dict: @@ -84,20 +87,18 @@ class RemoteTable(Table): def list_indices(self): """List all the indices on the table""" - resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/") - return resp + return self._loop.run_until_complete(self._table.list_indices()) def index_stats(self, index_uuid: str): """List all the stats of a specified index""" - resp = self._conn._client.post( - f"/v1/table/{self.name}/index/{index_uuid}/stats/" - ) - return resp + return self._loop.run_until_complete(self._table.index_stats(index_uuid)) def create_scalar_index( self, column: str, index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar", + *, + replace: bool = False, ): """Creates a scalar index Parameters @@ -107,20 +108,23 @@ class RemoteTable(Table): or string column. index_type : str The index type of the scalar index. Must be "scalar" (BTREE), - "BTREE", "BITMAP", or "LABEL_LIST" + "BTREE", "BITMAP", or "LABEL_LIST", + replace : bool + If True, replace the existing index with the new one. 
""" + if index_type == "scalar" or index_type == "BTREE": + config = BTree() + elif index_type == "BITMAP": + config = Bitmap() + elif index_type == "LABEL_LIST": + config = LabelList() + else: + raise ValueError(f"Unknown index type: {index_type}") - data = { - "column": column, - "index_type": index_type, - "replace": True, - } - resp = self._conn._client.post( - f"/v1/table/{self.name}/create_scalar_index/", data=data + self._loop.run_until_complete( + self._table.create_index(column, config=config, replace=replace) ) - return resp - def create_fts_index( self, column: str, @@ -128,15 +132,10 @@ class RemoteTable(Table): replace: bool = False, with_position: bool = True, ): - data = { - "column": column, - "index_type": "FTS", - "replace": replace, - } - resp = self._conn._client.post( - f"/v1/table/{self.name}/create_index/", data=data + config = FTS(with_position=with_position) + self._loop.run_until_complete( + self._table.create_index(column, config=config, replace=replace) ) - return resp def create_index( self, @@ -204,17 +203,22 @@ class RemoteTable(Table): "Existing indexes will always be replaced." ) - data = { - "column": vector_column_name, - "index_type": index_type, - "metric_type": metric, - "index_cache_size": index_cache_size, - } - resp = self._conn._client.post( - f"/v1/table/{self.name}/create_index/", data=data - ) + index_type = index_type.upper() + if index_type == "VECTOR" or index_type == "IVF_PQ": + config = IvfPq(distance_type=metric) + elif index_type == "IVF_HNSW_PQ": + config = HnswPq(distance_type=metric) + elif index_type == "IVF_HNSW_SQ": + config = HnswSq(distance_type=metric) + else: + raise ValueError( + f"Unknown vector index type: {index_type}. Valid options are" + " 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'" + ) - return resp + self._loop.run_until_complete( + self._table.create_index(vector_column_name, config=config) + ) def add( self, @@ -246,22 +250,10 @@ class RemoteTable(Table): The value to use when filling vectors. Only used if on_bad_vectors="fill". 
""" - data, _ = _sanitize_data( - data, - self.schema, - metadata=self.schema.metadata, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - payload = to_ipc_binary(data) - - request_id = uuid.uuid4().hex - - self._conn._client.post( - f"/v1/table/{self.name}/insert/", - data=payload, - params={"request_id": request_id, "mode": mode}, - content_type=ARROW_STREAM_CONTENT_TYPE, + self._loop.run_until_complete( + self._table.add( + data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) ) def search( @@ -337,12 +329,6 @@ class RemoteTable(Table): # empty query builder is not supported in saas, raise error if query is None and query_type != "hybrid": raise ValueError("Empty query is not supported") - vector_column_name = infer_vector_column_name( - schema=self.schema, - query_type=query_type, - query=query, - vector_column_name=vector_column_name, - ) return LanceQueryBuilder.create( self, @@ -356,37 +342,9 @@ class RemoteTable(Table): def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - if ( - query.vector is not None - and len(query.vector) > 0 - and not isinstance(query.vector[0], float) - ): - if self._conn._request_thread_pool is None: - - def submit(name, q): - f = Future() - f.set_result(self._conn._client.query(name, q)) - return f - - else: - - def submit(name, q): - return self._conn._request_thread_pool.submit( - self._conn._client.query, name, q - ) - - results = [] - for v in query.vector: - v = list(v) - q = query.copy() - q.vector = v - results.append(submit(self.name, q)) - return pa.concat_tables( - [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)] - ).to_reader() - else: - result = self._conn._client.query(self.name, query) - return result.to_arrow().to_reader() + return self._loop.run_until_complete( + self._table._execute_query(query, batch_size=batch_size) + ) def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] @@ -403,42 +361,8 @@ class RemoteTable(Table): on_bad_vectors: str, fill_value: float, ): - data, _ = _sanitize_data( - new_data, - self.schema, - metadata=None, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - payload = to_ipc_binary(data) - - params = {} - if len(merge._on) != 1: - raise ValueError( - "RemoteTable only supports a single on key in merge_insert" - ) - params["on"] = merge._on[0] - params["when_matched_update_all"] = str(merge._when_matched_update_all).lower() - if merge._when_matched_update_all_condition is not None: - params["when_matched_update_all_filt"] = ( - merge._when_matched_update_all_condition - ) - params["when_not_matched_insert_all"] = str( - merge._when_not_matched_insert_all - ).lower() - params["when_not_matched_by_source_delete"] = str( - merge._when_not_matched_by_source_delete - ).lower() - if merge._when_not_matched_by_source_condition is not None: - params["when_not_matched_by_source_delete_filt"] = ( - merge._when_not_matched_by_source_condition - ) - - self._conn._client.post( - f"/v1/table/{self.name}/merge_insert/", - data=payload, - params=params, - content_type=ARROW_STREAM_CONTENT_TYPE, + self._loop.run_until_complete( + self._table._do_merge(merge, new_data, on_bad_vectors, fill_value) ) def delete(self, predicate: str): @@ -488,8 +412,7 @@ class RemoteTable(Table): x vector _distance # doctest: +SKIP 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP """ - payload = {"predicate": predicate} - 
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload) + self._loop.run_until_complete(self._table.delete(predicate)) def update( self, @@ -539,18 +462,9 @@ class RemoteTable(Table): 2 2 [10.0, 10.0] # doctest: +SKIP """ - if values is not None and values_sql is not None: - raise ValueError("Only one of values or values_sql can be provided") - if values is None and values_sql is None: - raise ValueError("Either values or values_sql must be provided") - - if values is not None: - updates = [[k, value_to_sql(v)] for k, v in values.items()] - else: - updates = [[k, v] for k, v in values_sql.items()] - - payload = {"predicate": where, "updates": updates} - self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload) + self._loop.run_until_complete( + self._table.update(where=where, updates=values, updates_sql=values_sql) + ) def cleanup_old_versions(self, *_): """cleanup_old_versions() is not supported on the LanceDB cloud""" @@ -565,11 +479,7 @@ class RemoteTable(Table): ) def count_rows(self, filter: Optional[str] = None) -> int: - payload = {"predicate": filter} - resp = self._conn._client.post( - f"/v1/table/{self.name}/count_rows/", data=payload - ) - return resp + return self._loop.run_until_complete(self._table.count_rows(filter)) def add_columns(self, transforms: Dict[str, str]): raise NotImplementedError( diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index 6be646bd..c44355cf 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -12,7 +12,6 @@ # limitations under the License. import os -import requests from functools import cached_property from typing import Union @@ -57,6 +56,8 @@ class JinaReranker(Reranker): @cached_property def _client(self): + import requests + if os.environ.get("JINA_API_KEY") is None and self.api_key is None: raise ValueError( "JINA_API_KEY not set. Either set it in your environment or \ diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py index 8bcfb5e3..1aa8d6a1 100644 --- a/python/python/lancedb/rerankers/linear_combination.py +++ b/python/python/lancedb/rerankers/linear_combination.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
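The `jinaai.py` change above moves `import requests` from module scope into the cached `_client` property, so the package imports cleanly even when the optional dependency is absent. A short sketch of that deferred-import pattern, using a hypothetical class rather than the real `JinaReranker`:

```python
from functools import cached_property

class LazyHttpClient:
    @cached_property
    def _client(self):
        # Imported here, not at module scope, so importing this module never
        # requires `requests`; the import cost is paid once, on first access.
        import requests

        session = requests.Session()
        session.headers.update({"Accept": "application/json"})
        return session
```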
-from numpy import NaN +from numpy import nan import pyarrow as pa from .base import Reranker @@ -71,7 +71,7 @@ class LinearCombinationReranker(Reranker): elif self.score == "all": results = results.append_column( "_distance", - pa.array([NaN] * len(fts_results), type=pa.float32()), + pa.array([nan] * len(fts_results), type=pa.float32()), ) return results @@ -92,7 +92,7 @@ class LinearCombinationReranker(Reranker): elif self.score == "all": results = results.append_column( "_score", - pa.array([NaN] * len(vector_results), type=pa.float32()), + pa.array([nan] * len(vector_results), type=pa.float32()), ) return results diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 59dc4487..18e2c266 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -62,7 +62,7 @@ if TYPE_CHECKING: from lance.dataset import CleanupStats, ReaderLike from ._lancedb import Table as LanceDBTable, OptimizeStats from .db import LanceDBConnection - from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS + from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq pd = safe_import_pandas() pl = safe_import_polars() @@ -948,7 +948,9 @@ class Table(ABC): return _table_uri(self._conn.uri, self.name) def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]: - if get_uri_scheme(self._dataset_uri) != "file": + from .remote.table import RemoteTable + + if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file": return ("", None, False) path = join_uri(self._dataset_uri, "_indices", "fts") fs, path = fs_from_uri(path) @@ -2382,7 +2384,9 @@ class AsyncTable: column: str, *, replace: Optional[bool] = None, - config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None, + config: Optional[ + Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS] + ] = None, ): """Create an index to speed up queries @@ -2535,7 +2539,44 @@ class AsyncTable: async def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - pass + # The sync remote table calls into this method, so we need to map the + # query to the async version of the query and run that here. This is only + # used for that code path right now. 
+ async_query = self.query().limit(query.k) + if query.offset > 0: + async_query = async_query.offset(query.offset) + if query.columns: + async_query = async_query.select(query.columns) + if query.filter: + async_query = async_query.where(query.filter) + if query.fast_search: + async_query = async_query.fast_search() + if query.with_row_id: + async_query = async_query.with_row_id() + + if query.vector: + async_query = ( + async_query.nearest_to(query.vector) + .distance_type(query.metric) + .nprobes(query.nprobes) + ) + if query.refine_factor: + async_query = async_query.refine_factor(query.refine_factor) + if query.vector_column: + async_query = async_query.column(query.vector_column) + + if not query.prefilter: + async_query = async_query.postfilter() + + if isinstance(query.full_text_query, str): + async_query = async_query.nearest_to_text(query.full_text_query) + elif isinstance(query.full_text_query, dict): + fts_query = query.full_text_query["query"] + fts_columns = query.full_text_query.get("columns", []) or [] + async_query = async_query.nearest_to_text(fts_query, columns=fts_columns) + + table = await async_query.to_arrow() + return table.to_reader() async def _do_merge( self, @@ -2781,7 +2822,7 @@ class AsyncTable: cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000) return await self._inner.optimize(cleanup_older_than, delete_unverified) - async def list_indices(self) -> IndexConfig: + async def list_indices(self) -> Iterable[IndexConfig]: """ List all indices that have been created with Self::create_index """ @@ -2865,3 +2906,8 @@ class IndexStatistics: ] distance_type: Optional[Literal["l2", "cosine", "dot"]] = None num_indices: Optional[int] = None + + # This exists for backwards compatibility with an older API, which returned + # a dictionary instead of a class. 
+ def __getitem__(self, key): + return getattr(self, key) diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 87b2e249..9e17ca66 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -18,7 +18,6 @@ import lancedb import numpy as np import pandas as pd import pytest -import requests from lancedb.embeddings import get_registry from lancedb.pydantic import LanceModel, Vector @@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path): @pytest.mark.slow def test_openclip(tmp_path): + import requests from PIL import Image db = lancedb.connect(tmp_path) diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index ce649581..594552a0 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -235,6 +235,29 @@ async def test_search_fts_async(async_table): results = await async_table.query().nearest_to_text("puppy").limit(5).to_list() assert len(results) == 5 + expected_count = await async_table.count_rows( + "count > 5000 and contains(text, 'puppy')" + ) + expected_count = min(expected_count, 10) + + limited_results_pre_filter = await ( + async_table.query() + .nearest_to_text("puppy") + .where("count > 5000") + .limit(10) + .to_list() + ) + assert len(limited_results_pre_filter) == expected_count + limited_results_post_filter = await ( + async_table.query() + .nearest_to_text("puppy") + .where("count > 5000") + .limit(10) + .postfilter() + .to_list() + ) + assert len(limited_results_post_filter) <= expected_count + @pytest.mark.asyncio async def test_search_fts_specify_column_async(async_table): diff --git a/python/python/tests/test_index.py b/python/python/tests/test_index.py index 1245997e..3268179b 100644 --- a/python/python/tests/test_index.py +++ b/python/python/tests/test_index.py @@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable): # Can recreate if replace=True await some_table.create_index("id", replace=True) indices = await some_table.list_indices() - assert str(indices) == '[Index(BTree, columns=["id"])]' + assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]' assert len(indices) == 1 assert indices[0].index_type == "BTree" assert indices[0].columns == ["id"] @@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable): async def test_create_bitmap_index(some_table: AsyncTable): await some_table.create_index("id", config=Bitmap()) indices = await some_table.list_indices() - assert str(indices) == '[Index(Bitmap, columns=["id"])]' + assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]' indices = await some_table.list_indices() assert len(indices) == 1 index_name = indices[0].name @@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable): async def test_create_label_list_index(some_table: AsyncTable): await some_table.create_index("tags", config=LabelList()) indices = await some_table.list_indices() - assert str(indices) == '[Index(LabelList, columns=["tags"])]' + assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]' @pytest.mark.asyncio diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 11750e4d..b3f0d26a 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -17,6 +17,7 @@ from typing import Optional import lance import lancedb +from lancedb.index import IvfPq import numpy as np import pandas.testing as tm import pyarrow as pa 
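The `__getitem__` shim added to `IndexStatistics` above is what keeps old dict-style callers working against the new class-based API. A self-contained illustration of the idea with a hypothetical `Stats` class:

```python
from dataclasses import dataclass

@dataclass
class Stats:
    index_type: str
    num_indexed_rows: int

    def __getitem__(self, key):
        # The older API returned a plain dict, so route subscripting to the
        # matching attribute for backwards compatibility.
        return getattr(self, key)

s = Stats(index_type="BTree", num_indexed_rows=1000)
assert s["num_indexed_rows"] == s.num_indexed_rows
```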
@@ -330,6 +331,12 @@ async def test_query_async(table_async: AsyncTable): # Also check an empty query await check_query(table_async.query().where("id < 0"), expected_num_rows=0) + # with row id + await check_query( + table_async.query().select(["id", "vector"]).with_row_id(), + expected_columns=["id", "vector", "_rowid"], + ) + @pytest.mark.asyncio async def test_query_to_arrow_async(table_async: AsyncTable): @@ -358,6 +365,25 @@ async def test_query_to_pandas_async(table_async: AsyncTable): assert df.shape == (0, 4) +@pytest.mark.asyncio +async def test_fast_search_async(tmp_path): + db = await lancedb.connect_async(tmp_path) + vectors = pa.FixedShapeTensorArray.from_numpy_ndarray( + np.random.rand(256, 32) + ).storage + table = await db.create_table("test", pa.table({"vector": vectors})) + await table.create_index( + "vector", config=IvfPq(num_partitions=1, num_sub_vectors=1) + ) + await table.add(pa.table({"vector": vectors})) + + q = [1.0] * 32 + plan = await table.query().nearest_to(q).explain_plan(True) + assert "LanceScan" in plan + plan = await table.query().nearest_to(q).fast_search().explain_plan(True) + assert "LanceScan" not in plan + + def test_explain_plan(table): q = LanceVectorQueryBuilder(table, [0, 0], "vector") plan = q.explain_plan(verbose=True) diff --git a/python/python/tests/test_remote_client.py b/python/python/tests/test_remote_client.py deleted file mode 100644 index f5874953..00000000 --- a/python/python/tests/test_remote_client.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import attrs -import numpy as np -import pandas as pd -import pyarrow as pa -import pytest -from aiohttp import web -from lancedb.remote.client import RestfulLanceDBClient, VectorQuery - - -@attrs.define -class MockLanceDBServer: - runner: web.AppRunner = attrs.field(init=False) - site: web.TCPSite = attrs.field(init=False) - - async def query_handler(self, request: web.Request) -> web.Response: - table_name = request.match_info["table_name"] - assert table_name == "test_table" - - await request.json() - # TODO: do some matching - - vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector") - ids = pd.Series(range(10), name="id") - df = pd.DataFrame([vecs, ids]).T - - batch = pa.RecordBatch.from_pandas( - df, - schema=pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 128)), - pa.field("id", pa.int64()), - ] - ), - ) - - sink = pa.BufferOutputStream() - with pa.ipc.new_file(sink, batch.schema) as writer: - writer.write_batch(batch) - - return web.Response(body=sink.getvalue().to_pybytes()) - - async def setup(self): - app = web.Application() - app.add_routes([web.post("/table/{table_name}", self.query_handler)]) - self.runner = web.AppRunner(app) - await self.runner.setup() - self.site = web.TCPSite(self.runner, "localhost", 8111) - - async def start(self): - await self.site.start() - - async def stop(self): - await self.runner.cleanup() - - -@pytest.mark.skip(reason="flaky somehow, fix later") -@pytest.mark.asyncio -async def test_e2e_with_mock_server(): - mock_server = MockLanceDBServer() - await mock_server.setup() - await mock_server.start() - - try: - with RestfulLanceDBClient("lancedb+http://localhost:8111") as client: - df = ( - await client.query( - "test_table", - VectorQuery( - vector=np.random.rand(128).tolist(), - k=10, - _metric="L2", - columns=["id", "vector"], - ), - ) - ).to_pandas() - - assert "vector" in df.columns - assert "id" in df.columns - - assert client.closed - finally: - # make sure we don't leak resources - await mock_server.stop() diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index e03b6636..bc3a2783 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -2,91 +2,19 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import contextlib +from datetime import timedelta import http.server +import json import threading from unittest.mock import MagicMock import uuid import lancedb +from lancedb.conftest import MockTextEmbeddingFunction +from lancedb.remote import ClientConfig from lancedb.remote.errors import HttpError, RetryError -import pyarrow as pa -from lancedb.remote.client import VectorQuery, VectorQueryResult import pytest - - -class FakeLanceDBClient: - def close(self): - pass - - def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: - assert table_name == "test" - t = pa.schema([]).empty_table() - return VectorQueryResult(t) - - def post(self, path: str): - pass - - def mount_retry_adapter_for_table(self, table_name: str): - pass - - -def test_remote_db(): - conn = lancedb.connect("db://client-will-be-injected", api_key="fake") - setattr(conn, "_client", FakeLanceDBClient()) - - table = conn["test"] - table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) - table.search([1.0, 2.0]).to_pandas() - - -def test_create_empty_table(): - client = MagicMock() - conn = lancedb.connect("db://client-will-be-injected", api_key="fake") - - conn._client = client - - schema = pa.schema([pa.field("vector", 
pa.list_(pa.float32(), 2))]) - - client.post.return_value = {"status": "ok"} - table = conn.create_table("test", schema=schema) - assert table.name == "test" - assert client.post.call_args[0][0] == "/v1/table/test/create/" - - json_schema = { - "fields": [ - { - "name": "vector", - "nullable": True, - "type": { - "type": "fixed_size_list", - "fields": [ - {"name": "item", "nullable": True, "type": {"type": "float"}} - ], - "length": 2, - }, - }, - ] - } - client.post.return_value = {"schema": json_schema} - assert table.schema == schema - assert client.post.call_args[0][0] == "/v1/table/test/describe/" - - client.post.return_value = 0 - assert table.count_rows(None) == 0 - - -def test_create_table_with_recordbatches(): - client = MagicMock() - conn = lancedb.connect("db://client-will-be-injected", api_key="fake") - - conn._client = client - - batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"]) - - client.post.return_value = {"status": "ok"} - table = conn.create_table("test", [batch], schema=batch.schema) - assert table.name == "test" - assert client.post.call_args[0][0] == "/v1/table/test/create/" +import pyarrow as pa def make_mock_http_handler(handler): @@ -100,8 +28,35 @@ def make_mock_http_handler(handler): return MockLanceDBHandler +@contextlib.contextmanager +def mock_lancedb_connection(handler): + with http.server.HTTPServer( + ("localhost", 8080), make_mock_http_handler(handler) + ) as server: + handle = threading.Thread(target=server.serve_forever) + handle.start() + + db = lancedb.connect( + "db://dev", + api_key="fake", + host_override="http://localhost:8080", + client_config={ + "retry_config": {"retries": 2}, + "timeout_config": { + "connect_timeout": 1, + }, + }, + ) + + try: + yield db + finally: + server.shutdown() + handle.join() + + @contextlib.asynccontextmanager -async def mock_lancedb_connection(handler): +async def mock_lancedb_connection_async(handler): with http.server.HTTPServer( ("localhost", 8080), make_mock_http_handler(handler) ) as server: @@ -143,7 +98,7 @@ async def test_async_remote_db(): request.end_headers() request.wfile.write(b'{"tables": []}') - async with mock_lancedb_connection(handler) as db: + async with mock_lancedb_connection_async(handler) as db: table_names = await db.table_names() assert table_names == [] @@ -159,12 +114,12 @@ async def test_http_error(): request.end_headers() request.wfile.write(b"Internal Server Error") - async with mock_lancedb_connection(handler) as db: - with pytest.raises(HttpError, match="Internal Server Error") as exc_info: + async with mock_lancedb_connection_async(handler) as db: + with pytest.raises(HttpError) as exc_info: await db.table_names() assert exc_info.value.request_id == request_id_holder["request_id"] - assert exc_info.value.status_code == 507 + assert "Internal Server Error" in str(exc_info.value) @pytest.mark.asyncio @@ -178,15 +133,225 @@ async def test_retry_error(): request.end_headers() request.wfile.write(b"Try again later") - async with mock_lancedb_connection(handler) as db: - with pytest.raises(RetryError, match="Hit retry limit") as exc_info: + async with mock_lancedb_connection_async(handler) as db: + with pytest.raises(RetryError) as exc_info: await db.table_names() assert exc_info.value.request_id == request_id_holder["request_id"] - assert exc_info.value.status_code == 429 cause = exc_info.value.__cause__ assert isinstance(cause, HttpError) assert "Try again later" in str(cause) assert cause.request_id == request_id_holder["request_id"] assert 
cause.status_code == 429 + + +@contextlib.contextmanager +def query_test_table(query_handler): + def handler(request): + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"{}") + elif request.path == "/v1/table/test/query/": + content_len = int(request.headers.get("Content-Length")) + body = request.rfile.read(content_len) + body = json.loads(body) + + data = query_handler(body) + + request.send_response(200) + request.send_header("Content-Type", "application/vnd.apache.arrow.file") + request.end_headers() + + with pa.ipc.new_file(request.wfile, schema=data.schema) as f: + f.write_table(data) + else: + request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + assert repr(db) == "RemoteConnect(name=dev)" + table = db.open_table("test") + assert repr(table) == "RemoteTable(dev.test)" + yield table + + +def test_query_sync_minimal(): + def handler(body): + assert body == { + "distance_type": "l2", + "k": 10, + "prefilter": False, + "refine_factor": None, + "vector": [1.0, 2.0, 3.0], + "nprobes": 20, + } + + return pa.table({"id": [1, 2, 3]}) + + with query_test_table(handler) as table: + data = table.search([1, 2, 3]).to_list() + expected = [{"id": 1}, {"id": 2}, {"id": 3}] + assert data == expected + + +def test_query_sync_maximal(): + def handler(body): + assert body == { + "distance_type": "cosine", + "k": 42, + "prefilter": True, + "refine_factor": 10, + "vector": [1.0, 2.0, 3.0], + "nprobes": 5, + "filter": "id > 0", + "columns": ["id", "name"], + "vector_column": "vector2", + "fast_search": True, + "with_row_id": True, + } + + return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + + with query_test_table(handler) as table: + ( + table.search([1, 2, 3], vector_column_name="vector2", fast_search=True) + .metric("cosine") + .limit(42) + .refine_factor(10) + .nprobes(5) + .where("id > 0", prefilter=True) + .with_row_id(True) + .select(["id", "name"]) + .to_list() + ) + + +def test_query_sync_fts(): + def handler(body): + assert body == { + "full_text_query": { + "query": "puppy", + "columns": [], + }, + "k": 10, + "vector": [], + } + + return pa.table({"id": [1, 2, 3]}) + + with query_test_table(handler) as table: + (table.search("puppy", query_type="fts").to_list()) + + def handler(body): + assert body == { + "full_text_query": { + "query": "puppy", + "columns": ["name", "description"], + }, + "k": 42, + "vector": [], + "with_row_id": True, + } + + return pa.table({"id": [1, 2, 3]}) + + with query_test_table(handler) as table: + ( + table.search("puppy", query_type="fts", fts_columns=["name", "description"]) + .with_row_id(True) + .limit(42) + .to_list() + ) + + +def test_query_sync_hybrid(): + def handler(body): + if "full_text_query" in body: + # FTS query + assert body == { + "full_text_query": { + "query": "puppy", + "columns": [], + }, + "k": 42, + "vector": [], + "with_row_id": True, + } + return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]}) + else: + # Vector query + assert body == { + "distance_type": "l2", + "k": 42, + "prefilter": False, + "refine_factor": None, + "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "nprobes": 20, + "with_row_id": True, + } + return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]}) + + with query_test_table(handler) as table: + embedding_func = MockTextEmbeddingFunction() + embedding_config = MagicMock() + embedding_config.function 
= embedding_func
+
+        embedding_funcs = MagicMock()
+        embedding_funcs.get = MagicMock(return_value=embedding_config)
+        table.embedding_functions = embedding_funcs
+
+        (table.search("puppy", query_type="hybrid").limit(42).to_list())
+
+
+def test_create_client():
+    mandatory_args = {
+        "uri": "db://dev",
+        "api_key": "fake-api-key",
+        "region": "us-east-1",
+    }
+
+    db = lancedb.connect(**mandatory_args)
+    assert isinstance(db.client_config, ClientConfig)
+
+    db = lancedb.connect(**mandatory_args, client_config={})
+    assert isinstance(db.client_config, ClientConfig)
+
+    db = lancedb.connect(
+        **mandatory_args,
+        client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
+    )
+    assert isinstance(db.client_config, ClientConfig)
+    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
+
+    db = lancedb.connect(
+        **mandatory_args,
+        client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
+    )
+    assert isinstance(db.client_config, ClientConfig)
+    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
+
+    db = lancedb.connect(
+        **mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
+    )
+    assert isinstance(db.client_config, ClientConfig)
+    assert db.client_config.retry_config.retries == 42
+
+    db = lancedb.connect(
+        **mandatory_args, client_config={"retry_config": {"retries": 42}}
+    )
+    assert isinstance(db.client_config, ClientConfig)
+    assert db.client_config.retry_config.retries == 42
+
+    with pytest.warns(DeprecationWarning):
+        db = lancedb.connect(**mandatory_args, connection_timeout=42)
+        assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
+
+    with pytest.warns(DeprecationWarning):
+        db = lancedb.connect(**mandatory_args, read_timeout=42)
+        assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)
+
+    with pytest.warns(DeprecationWarning):
+        lancedb.connect(**mandatory_args, request_thread_pool=10)
diff --git a/python/src/connection.rs b/python/src/connection.rs
index 200285a4..46e15cfb 100644
--- a/python/src/connection.rs
+++ b/python/src/connection.rs
@@ -170,6 +170,17 @@ impl Connection {
         })
     }
 
+    pub fn rename_table(
+        self_: PyRef<'_, Self>,
+        old_name: String,
+        new_name: String,
+    ) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.get_inner()?.clone();
+        future_into_py(self_.py(), async move {
+            inner.rename_table(old_name, new_name).await.infer_error()
+        })
+    }
+
     pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
         let inner = self_.get_inner()?.clone();
         future_into_py(self_.py(), async move {
diff --git a/python/src/index.rs b/python/src/index.rs
index 7510b7fe..4ea4c19f 100644
--- a/python/src/index.rs
+++ b/python/src/index.rs
@@ -24,8 +24,8 @@ use lancedb::{
     DistanceType,
 };
 use pyo3::{
-    exceptions::{PyRuntimeError, PyValueError},
-    pyclass, pymethods, PyResult,
+    exceptions::{PyKeyError, PyRuntimeError, PyValueError},
+    pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
 };
 
 use crate::util::parse_distance_type;
@@ -236,7 +236,21 @@ pub struct IndexConfig {
 #[pymethods]
 impl IndexConfig {
     pub fn __repr__(&self) -> String {
-        format!("Index({}, columns={:?})", self.index_type, self.columns)
+        format!(
+            "Index({}, columns={:?}, name=\"{}\")",
+            self.index_type, self.columns, self.name
+        )
+    }
+
+    // For backwards-compatibility with the old sync SDK, we also support getting
+    // attributes via __getitem__.
+    pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult<PyObject> {
+        match key.as_str() {
+            "index_type" => Ok(self.index_type.clone().into_py(py)),
+            "columns" => Ok(self.columns.clone().into_py(py)),
+            "name" | "index_name" => Ok(self.name.clone().into_py(py)),
+            _ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))),
+        }
     }
 }
diff --git a/python/src/query.rs b/python/src/query.rs
index 42bd4a13..0f93f9ce 100644
--- a/python/src/query.rs
+++ b/python/src/query.rs
@@ -68,6 +68,18 @@ impl Query {
         self.inner = self.inner.clone().offset(offset as usize);
     }
 
+    pub fn fast_search(&mut self) {
+        self.inner = self.inner.clone().fast_search();
+    }
+
+    pub fn with_row_id(&mut self) {
+        self.inner = self.inner.clone().with_row_id();
+    }
+
+    pub fn postfilter(&mut self) {
+        self.inner = self.inner.clone().postfilter();
+    }
+
     pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> {
         let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
         let array = make_array(data);
@@ -146,6 +158,14 @@ impl VectorQuery {
         self.inner = self.inner.clone().offset(offset as usize);
     }
 
+    pub fn fast_search(&mut self) {
+        self.inner = self.inner.clone().fast_search();
+    }
+
+    pub fn with_row_id(&mut self) {
+        self.inner = self.inner.clone().with_row_id();
+    }
+
     pub fn column(&mut self, column: String) {
         self.inner = self.inner.clone().column(&column);
     }
diff --git a/rust/ffi/node/Cargo.toml b/rust/ffi/node/Cargo.toml
index 70c5baf0..3d570c58 100644
--- a/rust/ffi/node/Cargo.toml
+++ b/rust/ffi/node/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-node"
-version = "0.12.0"
+version = "0.13.0-beta.1"
 description = "Serverless, low-latency vector database for AI applications"
 license.workspace = true
 edition.workspace = true
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index 99a245bd..341970ec 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.12.0"
+version = "0.13.0-beta.1"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index 44a6b443..40329b66 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -39,9 +39,6 @@ use crate::utils::validate_table_name;
 use crate::Table;
 pub use lance_encoding::version::LanceFileVersion;
 
-#[cfg(feature = "remote")]
-use log::warn;
-
 pub const LANCE_FILE_EXTENSION: &str = "lance";
 
 pub type TableBuilderCallback = Box<dyn Fn(OpenTableBuilder) -> OpenTableBuilder + Send>;
@@ -719,8 +716,7 @@ impl ConnectBuilder {
         let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
             message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
         })?;
-        // TODO: remove this warning when the remote client is ready
-        warn!("The rust implementation of the remote client is not yet ready for use.");
+
         let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
             &self.uri,
             &api_key,
diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index 6118e6b7..135f46a1 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -403,6 +403,26 @@ pub trait QueryBase {
     /// By default, it is false.
     fn fast_search(self) -> Self;
 
+    /// If this is called then filtering will happen after the vector search instead of
+    /// before.
+    ///
+    /// By default filtering will be performed before the vector search. This is how
+    /// filtering is typically understood to work.
This prefilter step does add some
+    /// additional latency. Creating a scalar index on the filter column(s) can
+    /// often improve this latency. However, sometimes a filter is too complex or scalar
+    /// indices cannot be applied to the column. In these cases postfiltering can be
+    /// used instead of prefiltering to improve latency.
+    ///
+    /// Post filtering applies the filter to the results of the vector search. This means
+    /// we only run the filter on a much smaller set of data. However, it can cause the
+    /// query to return fewer than `limit` results (or even no results) if none of the nearest
+    /// results match the filter.
+    ///
+    /// Post filtering happens during the "refine stage" (described in more detail in
+    /// [`Self::refine_factor`]). This means that setting a higher refine factor can often
+    /// help restore some of the results lost by post filtering.
+    fn postfilter(self) -> Self;
+
     /// Return the `_rowid` meta column from the Table.
     fn with_row_id(self) -> Self;
 }
@@ -442,6 +462,11 @@ impl<T: HasQuery> QueryBase for T {
         self
     }
 
+    fn postfilter(mut self) -> Self {
+        self.mut_query().prefilter = false;
+        self
+    }
+
     fn with_row_id(mut self) -> Self {
         self.mut_query().with_row_id = true;
         self
@@ -561,6 +586,9 @@ pub struct Query {
     ///
     /// By default, this is false.
     pub(crate) with_row_id: bool,
+
+    /// If set to false, the filter will be applied after the vector search.
+    pub(crate) prefilter: bool,
 }
 
 impl Query {
@@ -574,6 +602,7 @@ impl Query {
             select: Select::All,
             fast_search: false,
             with_row_id: false,
+            prefilter: true,
         }
     }
 
@@ -678,8 +707,6 @@ pub struct VectorQuery {
     pub(crate) distance_type: Option<DistanceType>,
     /// Default is true. Set to false to enforce a brute force search.
     pub(crate) use_index: bool,
-    /// Apply filter before ANN search/
-    pub(crate) prefilter: bool,
 }
 
 impl VectorQuery {
@@ -692,7 +719,6 @@
             refine_factor: None,
             distance_type: None,
             use_index: true,
-            prefilter: true,
         }
     }
 
-    /// If this is called then filtering will happen after the vector search instead of
-    /// before.
-    ///
-    /// By default filtering will be performed before the vector search. This is how
-    /// filtering is typically understood to work. This prefilter step does add some
-    /// additional latency. Creating a scalar index on the filter column(s) can
-    /// often improve this latency. However, sometimes a filter is too complex or scalar
-    /// indices cannot be applied to the column. In these cases postfiltering can be
-    /// used instead of prefiltering to improve latency.
-    ///
-    /// Post filtering applies the filter to the results of the vector search. This means
-    /// we only run the filter on a much smaller set of data. However, it can cause the
-    /// query to return fewer than `limit` results (or even no results) if none of the nearest
-    /// results match the filter.
-    ///
-    /// Post filtering happens during the "refine stage" (described in more detail in
-    /// [`Self::refine_factor`]). This means that setting a higher refine factor can often
-    /// help restore some of the results lost by post filtering.
-    pub fn postfilter(mut self) -> Self {
-        self.prefilter = false;
-        self
-    }
-
     /// If this is called then any vector index is skipped
     ///
     /// An exhaustive (flat) search will be performed.
The query vector will
diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs
index 50789108..48c8aa1c 100644
--- a/rust/lancedb/src/remote/client.rs
+++ b/rust/lancedb/src/remote/client.rs
@@ -341,7 +341,22 @@ impl RestfulLanceDbClient {
             request_id
         };
 
-        debug!("Sending request_id={}: {:?}", request_id, &request);
+        if log::log_enabled!(log::Level::Debug) {
+            let content_type = request
+                .headers()
+                .get("content-type")
+                .map(|v| v.to_str().unwrap());
+            if content_type == Some("application/json") {
+                let body = request.body().as_ref().unwrap().as_bytes().unwrap();
+                let body = String::from_utf8_lossy(body);
+                debug!(
+                    "Sending request_id={}: {:?} with body {}",
+                    request_id, request, body
+                );
+            } else {
+                debug!("Sending request_id={}: {:?}", request_id, request);
+            }
+        }
 
         if with_retry {
             self.send_with_retry_impl(client, request, request_id).await
diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs
index 8fe415be..05b5dfe2 100644
--- a/rust/lancedb/src/remote/db.rs
+++ b/rust/lancedb/src/remote/db.rs
@@ -161,7 +161,7 @@ impl ConnectionInternal for RemoteDatabase {
         if self.table_cache.get(&options.name).is_none() {
             let req = self
                 .client
-                .get(&format!("/v1/table/{}/describe/", options.name));
+                .post(&format!("/v1/table/{}/describe/", options.name));
             let (request_id, resp) = self.client.send(req, true).await?;
             if resp.status() == StatusCode::NOT_FOUND {
                 return Err(crate::Error::TableNotFound { name: options.name });
@@ -301,7 +301,7 @@ mod tests {
     #[tokio::test]
     async fn test_open_table() {
         let conn = Connection::new_with_handler(|request| {
-            assert_eq!(request.method(), &reqwest::Method::GET);
+            assert_eq!(request.method(), &reqwest::Method::POST);
             assert_eq!(request.url().path(), "/v1/table/table1/describe/");
             assert_eq!(request.url().query(), None);
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index f9900b2c..a8754cc3 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -167,6 +167,10 @@ impl RemoteTable {
             body["fast_search"] = serde_json::Value::Bool(true);
         }
 
+        if params.with_row_id {
+            body["with_row_id"] = serde_json::Value::Bool(true);
+        }
+
         if let Some(full_text_search) = &params.full_text_search {
             if full_text_search.wand_factor.is_some() {
                 return Err(Error::NotSupported {
@@ -305,13 +309,13 @@ impl TableInternal for RemoteTable {
         let mut body = serde_json::Value::Object(Default::default());
 
         Self::apply_query_params(&mut body, &query.base)?;
-        body["prefilter"] = query.prefilter.into();
+        body["prefilter"] = query.base.prefilter.into();
 
         body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
         body["nprobes"] = query.nprobes.into();
         body["refine_factor"] = query.refine_factor.into();
 
-        if let Some(vector) = query.query_vector.as_ref() {
-            let vector: Vec<f32> = match vector.data_type() {
+        let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
+            match vector.data_type() {
                 DataType::Float32 => vector
                     .as_any()
                     .downcast_ref::<Float32Array>()
                     .unwrap()
@@ -325,9 +329,12 @@
                         message: "VectorQuery vector must be of type Float32".into(),
                     })
                 }
-            };
-            body["vector"] = serde_json::json!(vector);
-        }
+            }
+        } else {
+            // Server takes empty vector, not null or undefined.
+ Vec::new() + }; + body["vector"] = serde_json::json!(vector); if let Some(vector_column) = query.column.as_ref() { body["vector_column"] = serde_json::Value::String(vector_column.clone()); @@ -358,6 +365,8 @@ impl TableInternal for RemoteTable { let mut body = serde_json::Value::Object(Default::default()); Self::apply_query_params(&mut body, query)?; + // Empty vector can be passed if no vector search is performed. + body["vector"] = serde_json::Value::Array(Vec::new()); let request = request.json(&body); @@ -379,7 +388,7 @@ impl TableInternal for RemoteTable { let request = request.json(&serde_json::json!({ "updates": updates, - "only_if": update.filter, + "predicate": update.filter, })); let (request_id, response) = self.client.send(request, false).await?; @@ -933,7 +942,7 @@ mod tests { assert_eq!(col_name, "b"); assert_eq!(expression, "b - 1"); - let only_if = value.get("only_if").unwrap().as_str().unwrap(); + let only_if = value.get("predicate").unwrap().as_str().unwrap(); assert_eq!(only_if, "b > 10"); } @@ -1167,6 +1176,8 @@ mod tests { "query": "hello world", }, "k": 10, + "vector": [], + "with_row_id": true, }); assert_eq!(body, expected_body); @@ -1189,6 +1200,7 @@ mod tests { FullTextSearchQuery::new("hello world".into()) .columns(Some(vec!["a".into(), "b".into()])), ) + .with_row_id() .limit(10) .execute() .await diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index a94526ca..ee5e5bba 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1842,7 +1842,7 @@ impl TableInternal for NativeTable { scanner.nprobs(query.nprobes); scanner.use_index(query.use_index); - scanner.prefilter(query.prefilter); + scanner.prefilter(query.base.prefilter); match query.base.select { Select::Columns(ref columns) => { scanner.project(columns.as_slice())?;
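Taken together, the query changes move `prefilter` from `VectorQuery` onto the base `Query`, so post-filtering is available to every query type and is forwarded to the remote server. A hedged sketch of how a caller exercises both filtering modes through the sync Python API used in the tests above; the local path and table contents are illustrative only:

```python
import lancedb
import numpy as np

db = lancedb.connect("data/sample-lancedb")  # hypothetical local path
tbl = db.create_table(
    "demo",
    [{"vector": list(np.random.rand(2)), "id": i} for i in range(100)],
)

# Prefilter (the default): rows matching the filter are selected before the
# ANN search runs.
pre = tbl.search([0.1, 0.2]).where("id > 50", prefilter=True).limit(10).to_list()

# Postfilter: the ANN search runs first, then the filter is applied to its
# results, which may return fewer than `limit` rows.
post = tbl.search([0.1, 0.2]).where("id > 50", prefilter=False).limit(10).to_list()
```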