Mirror of https://github.com/lancedb/lancedb.git, synced 2025-12-23 05:19:58 +00:00

# Compare commits

`python-v0.` ... `python-v0.` (39 commits)
## Commits

Only the commit SHAs survive in this mirror; author and date columns are empty.

- `51437bc228`
- `fa53cfcfd2`
- `374fe0ad95`
- `35e5b84ba9`
- `7c12d497b0`
- `dfe4ba8dad`
- `fa1b9ad5bd`
- `8877eb020d`
- `01e4291d21`
- `ab3ea76ad1`
- `728ef8657d`
- `0b13901a16`
- `84b110e0ef`
- `e1836e54e3`
- `4ba5326880`
- `b036a69300`
- `5b12a47119`
- `769d483e50`
- `9ecb11fe5a`
- `22bd8329f3`
- `a736fad149`
- `072adc41aa`
- `c6f25ef1f0`
- `2f0c5baea2`
- `a63dd66d41`
- `d6b3ccb37b`
- `c4f99e82e5`
- `979a2d3d9d`
- `7ac5f74c80`
- `ecdee4d2b1`
- `f391ed828a`
- `a99a450f2b`
- `6fa1f37506`
- `544382df5e`
- `784f00ef6d`
- `96d7446f70`
- `99ea78fb55`
- `8eef4cdc28`
- `0f102f02c3`
## Files changed

### Version bump configuration

```diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.16.1-beta.2"
+current_version = "0.18.0-beta.0"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
```
### .github/workflows/python.yml (vendored, 14 lines changed)

CI moves from `ubuntu-22.04` to `ubuntu-24.04` and from Python 3.11 to 3.12 throughout.

```diff
@@ -41,7 +41,7 @@ jobs:
   doctest:
     name: "Doctest"
     timeout-minutes: 30
-    runs-on: "ubuntu-22.04"
+    runs-on: "ubuntu-24.04"
     defaults:
       run:
         shell: bash
@@ -54,7 +54,7 @@ jobs:
       - name: Set up Python
        uses: actions/setup-python@v5
        with:
-         python-version: "3.11"
+         python-version: "3.12"
          cache: "pip"
      - name: Install protobuf
        run: |
@@ -75,8 +75,8 @@ jobs:
     timeout-minutes: 30
     strategy:
       matrix:
-        python-minor-version: ["9", "11"]
-    runs-on: "ubuntu-22.04"
+        python-minor-version: ["9", "12"]
+    runs-on: "ubuntu-24.04"
     defaults:
       run:
         shell: bash
@@ -127,7 +127,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-         python-version: "3.11"
+         python-version: "3.12"
      - uses: Swatinem/rust-cache@v2
        with:
          workspaces: python
@@ -157,7 +157,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-         python-version: "3.11"
+         python-version: "3.12"
      - uses: Swatinem/rust-cache@v2
        with:
          workspaces: python
@@ -168,7 +168,7 @@ jobs:
        run: rm -rf target/wheels
   pydantic1x:
     timeout-minutes: 30
-    runs-on: "ubuntu-22.04"
+    runs-on: "ubuntu-24.04"
     defaults:
       run:
         shell: bash
```
### .github/workflows/rust.yml (vendored, 23 lines changed)

```diff
@@ -61,7 +61,12 @@ jobs:
       CXX: clang++
     steps:
       - uses: actions/checkout@v4
-      # Remote cargo.lock to force a fresh build
+      # Building without a lock file often requires the latest Rust version since downstream
+      # dependencies may have updated their minimum Rust version.
+      - uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          toolchain: "stable"
+      # Remove cargo.lock to force a fresh build
       - name: Remove Cargo.lock
         run: rm -f Cargo.lock
       - uses: rui314/setup-mold@v1
@@ -179,15 +184,17 @@ jobs:
     steps:
       - name: Checkout
         uses: actions/checkout@v4
-      - name: Install dependencies
+      - name: Install dependencies (part 1)
        run: |
          set -e
          apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
-
-         curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y
-         source $HOME/.cargo/env
-         rustup target add aarch64-pc-windows-msvc
-
+      - name: Install rust
+        uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          target: aarch64-pc-windows-msvc
+      - name: Install dependencies (part 2)
+        run: |
+          set -e
          mkdir -p sysroot
          cd sysroot
          sh ../ci/sysroot-aarch64-pc-windows-msvc.sh
@@ -259,7 +266,7 @@ jobs:
       - name: Install Rust
         run: |
           Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
-          .\rustup-init.exe -y --default-host aarch64-pc-windows-msvc
+          .\rustup-init.exe -y --default-host aarch64-pc-windows-msvc --default-toolchain 1.83.0
         shell: powershell
      - name: Add Rust to PATH
        run: |
```
### Cargo.lock (generated, 242 lines changed)

All `lance*` crates move from git tag `v0.23.1-beta.4` (commit `6b58bc16230faeb5387c5478c485254a52e9787f`) to `v0.24.0-beta.1` (commit `33ae43b2944c12e0dbd139e8aa098cffa74edef5`), the workspace crates are bumped for the release, and registry checksums are updated accordingly. Version changes:

| Package | Old version | New version |
|---|---|---|
| anyhow | 1.0.95 | 1.0.96 |
| aws-config | 1.5.16 | 1.5.17 |
| aws-sdk-bedrockruntime | 1.74.0 | 1.75.0 |
| aws-sdk-dynamodb | 1.65.0 | 1.66.0 |
| aws-sdk-kms | 1.60.0 | 1.61.0 |
| aws-sdk-s3 | 1.76.0 | 1.77.0 |
| aws-sdk-sso | 1.59.0 | 1.60.0 |
| aws-sdk-ssooidc | 1.60.0 | 1.61.0 |
| aws-sdk-sts | 1.60.0 | 1.61.0 |
| aws-smithy-checksums | 0.62.0 | 0.63.0 |
| bytemuck | 1.18.0 | 1.21.0 |
| cc | 1.2.14 | 1.2.15 |
| either | 1.13.0 | 1.14.0 |
| flate2 | 1.0.35 | 1.1.0 |
| fsst | 0.23.1 | 0.24.0 |
| lance | 0.23.1 | 0.24.0 |
| lance-arrow | 0.23.1 | 0.24.0 |
| lance-core | 0.23.1 | 0.24.0 |
| lance-datafusion | 0.23.1 | 0.24.0 |
| lance-encoding | 0.23.1 | 0.24.0 |
| lance-file | 0.23.1 | 0.24.0 |
| lance-index | 0.23.1 | 0.24.0 |
| lance-io | 0.23.1 | 0.24.0 |
| lance-linalg | 0.23.1 | 0.24.0 |
| lance-table | 0.23.1 | 0.24.0 |
| lance-testing | 0.23.1 | 0.24.0 |
| lancedb | 0.16.1-beta.2 | 0.18.0-beta.0 |
| lancedb-node | 0.16.1-beta.2 | 0.18.0-beta.0 |
| lancedb-nodejs | 0.16.1-beta.2 | 0.18.0-beta.0 |
| lancedb-python | 0.19.1-beta.2 | 0.21.0-beta.0 |
| libc | 0.2.169 | 0.2.170 |
| litemap | 0.7.4 | 0.7.5 |
| log | 0.4.25 | 0.4.26 |
| miniz_oxide | 0.8.4 | 0.8.5 |
| napi-build | 2.1.4 | 2.1.5 |
| native-tls | 0.2.13 | 0.2.14 |
| oneshot | 0.1.10 | 0.1.11 |
| portable-atomic | 1.10.0 | 1.11.0 |
| redox_syscall | 0.5.8 | 0.5.9 |
| ring | 0.17.9 | 0.17.11 |
| roaring | 0.10.9 | 0.10.10 |
| serde | 1.0.217 | 1.0.218 |
| serde_derive | 1.0.217 | 1.0.218 |
| serde_json | 1.0.138 | 1.0.139 |
| stacker | 0.1.18 | 0.1.19 |
| unicode-ident | 1.0.16 | 1.0.17 |
| uuid | 1.13.2 | 1.15.0 |
| winnow | 0.7.2 | 0.7.3 |
| zerofrom | 0.1.5 | 0.1.6 |
| zerofrom-derive | 0.1.5 | 0.1.6 |
| zstd | 0.13.2 | 0.13.3 |
| zstd-safe | 7.2.1 | 7.2.3 |
| zstd-sys | 2.0.13+zstd.1.5.6 | 2.0.14+zstd.1.5.7 |

New lockfile entries: `relative-path 1.9.3`, `rstest 0.23.0`, and `rstest_macros 0.23.0`. The `lancedb` crate gains two dependencies, `rstest` and `semver 1.0.25`.
### Cargo.toml (26 lines changed)

```diff
@@ -21,16 +21,16 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"
 
 [workspace.dependencies]
-lance = { "version" = "=0.23.1", "features" = [
+lance = { "version" = "=0.24.0", "features" = [
   "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.23.1-beta.4"}
-lance-io = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
-lance-index = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
-lance-linalg = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
-lance-table = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
-lance-testing = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
-lance-datafusion = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
-lance-encoding = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
+], git = "https://github.com/lancedb/lance.git", tag = "v0.24.0-beta.1" }
+lance-io = { version = "=0.24.0", tag = "v0.24.0-beta.1", git = "https://github.com/lancedb/lance.git" }
+lance-index = { version = "=0.24.0", tag = "v0.24.0-beta.1", git = "https://github.com/lancedb/lance.git" }
+lance-linalg = { version = "=0.24.0", tag = "v0.24.0-beta.1", git = "https://github.com/lancedb/lance.git" }
+lance-table = { version = "=0.24.0", tag = "v0.24.0-beta.1", git = "https://github.com/lancedb/lance.git" }
+lance-testing = { version = "=0.24.0", tag = "v0.24.0-beta.1", git = "https://github.com/lancedb/lance.git" }
+lance-datafusion = { version = "=0.24.0", tag = "v0.24.0-beta.1", git = "https://github.com/lancedb/lance.git" }
+lance-encoding = { version = "=0.24.0", tag = "v0.24.0-beta.1", git = "https://github.com/lancedb/lance.git" }
 # Note that this one does not include pyarrow
 arrow = { version = "53.2", optional = false }
 arrow-array = "53.2"
@@ -41,7 +41,6 @@ arrow-schema = "53.2"
 arrow-arith = "53.2"
 arrow-cast = "53.2"
 async-trait = "0"
-chrono = "0.4.35"
 datafusion = { version = "44.0", default-features = false }
 datafusion-catalog = "44.0"
 datafusion-common = { version = "44.0", default-features = false }
@@ -63,6 +62,13 @@ num-traits = "0.2"
 rand = "0.8"
 regex = "1.10"
 lazy_static = "1"
+semver = "1.0.25"
+
+# Temporary pins to work around downstream issues
+# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
+chrono = "=0.4.39"
+# https://github.com/RustCrypto/formats/issues/1684
+base64ct = "=1.6.0"
 
 # Workaround for: https://github.com/eira-fransham/crunchy/issues/13
 crunchy = "=0.2.2"
```
### MkDocs site configuration

Several `nav` entries change only in indentation, so the removed and added lines below render identically.

```diff
@@ -4,6 +4,9 @@ repo_url: https://github.com/lancedb/lancedb
 edit_uri: https://github.com/lancedb/lancedb/tree/main/docs/src
 repo_name: lancedb/lancedb
 docs_dir: src
+watch:
+  - src
+  - ../python/python
 
 theme:
   name: "material"
@@ -63,6 +66,7 @@ plugins:
       - https://arrow.apache.org/docs/objects.inv
       - https://pandas.pydata.org/docs/objects.inv
      - https://lancedb.github.io/lance/objects.inv
+      - https://docs.pydantic.dev/latest/objects.inv
   - mkdocs-jupyter
   - render_swagger:
       allow_arbitrary_locations: true
@@ -105,8 +109,8 @@ nav:
   - 📚 Concepts:
       - Vector search: concepts/vector_search.md
       - Indexing:
-        - IVFPQ: concepts/index_ivfpq.md
-        - HNSW: concepts/index_hnsw.md
+        - IVFPQ: concepts/index_ivfpq.md
+        - HNSW: concepts/index_hnsw.md
       - Storage: concepts/storage.md
       - Data management: concepts/data_management.md
   - 🔨 Guides:
@@ -130,8 +134,8 @@ nav:
       - Adaptive RAG: rag/adaptive_rag.md
       - SFR RAG: rag/sfr_rag.md
       - Advanced Techniques:
-        - HyDE: rag/advanced_techniques/hyde.md
-        - FLARE: rag/advanced_techniques/flare.md
+        - HyDE: rag/advanced_techniques/hyde.md
+        - FLARE: rag/advanced_techniques/flare.md
       - Reranking:
           - Quickstart: reranking/index.md
          - Cohere Reranker: reranking/cohere.md
@@ -146,7 +150,7 @@ nav:
       - Building Custom Rerankers: reranking/custom_reranker.md
       - Example: notebooks/lancedb_reranking.ipynb
       - Filtering: sql.md
-      - Versioning & Reproducibility:
+      - Versioning & Reproducibility:
           - sync API: notebooks/reproducibility.ipynb
           - async API: notebooks/reproducibility_async.ipynb
       - Configuring Storage: guides/storage.md
@@ -178,6 +182,7 @@ nav:
       - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md
       - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md
       - User-defined embedding functions: embeddings/custom_embedding_function.md
+      - Variables and secrets: embeddings/variables_and_secrets.md
       - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
       - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
   - 🔌 Integrations:
@@ -240,8 +245,8 @@ nav:
   - Concepts:
       - Vector search: concepts/vector_search.md
       - Indexing:
-        - IVFPQ: concepts/index_ivfpq.md
-        - HNSW: concepts/index_hnsw.md
+        - IVFPQ: concepts/index_ivfpq.md
+        - HNSW: concepts/index_hnsw.md
       - Storage: concepts/storage.md
       - Data management: concepts/data_management.md
   - Guides:
@@ -265,8 +270,8 @@ nav:
       - Adaptive RAG: rag/adaptive_rag.md
       - SFR RAG: rag/sfr_rag.md
       - Advanced Techniques:
-        - HyDE: rag/advanced_techniques/hyde.md
-        - FLARE: rag/advanced_techniques/flare.md
+        - HyDE: rag/advanced_techniques/hyde.md
+        - FLARE: rag/advanced_techniques/flare.md
       - Reranking:
           - Quickstart: reranking/index.md
          - Cohere Reranker: reranking/cohere.md
@@ -280,7 +285,7 @@ nav:
       - Building Custom Rerankers: reranking/custom_reranker.md
       - Example: notebooks/lancedb_reranking.ipynb
       - Filtering: sql.md
-      - Versioning & Reproducibility:
+      - Versioning & Reproducibility:
           - sync API: notebooks/reproducibility.ipynb
           - async API: notebooks/reproducibility_async.ipynb
       - Configuring Storage: guides/storage.md
@@ -311,6 +316,7 @@ nav:
       - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md
       - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md
       - User-defined embedding functions: embeddings/custom_embedding_function.md
+      - Variables and secrets: embeddings/variables_and_secrets.md
       - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
       - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
   - Integrations:
@@ -349,8 +355,8 @@ nav:
   - 🦀 Rust:
       - Overview: examples/examples_rust.md
   - Studies:
-      - studies/overview.md
-      - ↗Improve retrievers with hybrid search and reranking: https://blog.lancedb.com/hybrid-search-and-reranking-report/
+      - studies/overview.md
+      - ↗Improve retrievers with hybrid search and reranking: https://blog.lancedb.com/hybrid-search-and-reranking-report/
   - API reference:
       - Overview: api_reference.md
       - Python: python/python.md
@@ -371,6 +377,7 @@ extra_css:
 
 extra_javascript:
   - "extra_js/init_ask_ai_widget.js"
+  - "extra_js/reo.js"
 
 extra:
   analytics:
```
### Custom embedding function guide

```diff
@@ -55,6 +55,14 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
 This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and default settings.
 
+!!! danger "Use sensitive keys to prevent leaking secrets"
+    To prevent leaking secrets, such as API keys, you should add any sensitive
+    parameters of an embedding function to the output of the
+    [sensitive_keys()][lancedb.embeddings.base.EmbeddingFunction.sensitive_keys] /
+    [getSensitiveKeys()](../../js/namespaces/embedding/classes/EmbeddingFunction/#getsensitivekeys)
+    method. This prevents users from accidentally instantiating the embedding
+    function with hard-coded secrets.
+
 Now you can use this embedding function to create your table schema and that's it! you can then ingest data and run queries without manually vectorizing the inputs.
 
 === "Python"
```
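The admonition above only names the new hook. As a rough, hypothetical sketch (not taken from the repository, and assuming `sensitive_keys()` is an instance method returning field names), a custom Python function might mark its key as sensitive like this:

```python
from lancedb.embeddings import register
from lancedb.embeddings.base import TextEmbeddingFunction


@register("my-embeddings")  # hypothetical registry name
class MyEmbeddings(TextEmbeddingFunction):
    api_key: str  # must arrive via an env var or a "$var:" reference
    model: str = "my-model"  # illustrative model name
    device: str = "cpu"

    def sensitive_keys(self):
        # Raw literal values for these fields are rejected at construction time.
        return ["api_key"]

    def ndims(self):
        return 384  # illustrative embedding width

    def generate_embeddings(self, texts):
        # Call the real embedding service here.
        raise NotImplementedError
```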
### docs/src/embeddings/variables_and_secrets.md (new file, 53 lines)

````diff
@@ -0,0 +1,53 @@
+# Variables and Secrets
+
+Most embedding configuration options are saved in the table's metadata. However,
+this isn't always appropriate. For example, API keys should never be stored in the
+metadata. Additionally, other configuration options might be best set at runtime,
+such as the `device` configuration that controls whether to use GPU or CPU for
+inference. If you hardcoded this to GPU, you wouldn't be able to run the code on
+a server without one.
+
+To handle these cases, you can set variables on the embedding registry and
+reference them in the embedding configuration. These variables will be available
+during the runtime of your program, but not saved in the table's metadata. When
+the table is loaded from a different process, the variables must be set again.
+
+To set a variable, use the `set_var()` / `setVar()` method on the embedding registry.
+To reference a variable, use the syntax `$var:VARIABLE_NAME`. If there is a default
+value, you can use the syntax `$var:VARIABLE_NAME:DEFAULT_VALUE`.
+
+## Using variables to set secrets
+
+Sensitive configuration, such as API keys, must either be set as environment
+variables or using variables on the embedding registry. If you pass in a hardcoded
+value, LanceDB will raise an error. Instead, if you want to set an API key via
+configuration, use a variable:
+
+=== "Python"
+
+    ```python
+    --8<-- "python/python/tests/docs/test_embeddings_optional.py:register_secret"
+    ```
+
+=== "Typescript"
+
+    ```typescript
+    --8<-- "nodejs/examples/embedding.test.ts:register_secret"
+    ```
+
+## Using variables to set the device parameter
+
+Many embedding functions that run locally have a `device` parameter that controls
+whether to use GPU or CPU for inference. Because not all computers have a GPU,
+it's helpful to be able to set the `device` parameter at runtime, rather than
+have it hard coded in the embedding configuration. To make it work even if the
+variable isn't set, you could provide a default value of `cpu` in the embedding
+configuration.
+
+Some embedding libraries even have a method to detect which devices are available,
+which could be used to dynamically set the device at runtime. For example, in Python
+you can check if a CUDA GPU is available using `torch.cuda.is_available()`.
+
+```python
+--8<-- "python/python/tests/docs/test_embeddings_optional.py:register_device"
+```
````
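The snippet includes above point at the repository's test suite and are not reproduced in this mirror. A minimal sketch of the workflow the page describes, using the `set_var()` API and the `$var:` syntax documented for `setVar()` below; the variable names are illustrative, and the function reuses the hypothetical `my-embeddings` class from the earlier sketch:

```python
import os

from lancedb.embeddings import get_registry

registry = get_registry()
# Set once per process; the value is never written to the table's metadata.
registry.set_var("api_key", os.environ["MY_API_KEY"])
# A runtime-only knob; "$var:device:cpu" falls back to "cpu" when unset.
registry.set_var("device", "cuda")

# Reference variables instead of raw values in the embedding configuration.
func = registry.get("my-embeddings").create(
    api_key="$var:api_key",
    device="$var:device:cpu",
)
```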
### docs/src/extra_js/reo.js (new file, 1 line)

```diff
@@ -0,0 +1 @@
+!function(){var e,t,n;e="9627b71b382d201",t=function(){Reo.init({clientID:"9627b71b382d201"})},(n=document.createElement("script")).src="https://static.reo.dev/"+e+"/reo.js",n.defer=!0,n.onload=t,document.head.appendChild(n)}();
```
### EmbeddingFunction (TypeScript API reference)

The added example originally named its constructor parameter `options` while using `optionsRaw` in the body; the parameter name is corrected below.

````diff
@@ -8,6 +8,23 @@
 An embedding function that automatically creates vector representation for a given column.
 
+It's important subclasses pass the **original** options to the super constructor
+and then pass those options to `resolveVariables` to resolve any variables before
+using them.
+
+## Example
+
+```ts
+class MyEmbeddingFunction extends EmbeddingFunction {
+  constructor(optionsRaw: {model: string, timeout: number}) {
+    super(optionsRaw);
+    const options = this.resolveVariables(optionsRaw);
+    this.model = options.model;
+    this.timeout = options.timeout;
+  }
+}
+```
+
 ## Extended by
 
 - [`TextEmbeddingFunction`](TextEmbeddingFunction.md)
@@ -82,12 +99,33 @@ The datatype of the embeddings
 ***
 
+### getSensitiveKeys()
+
+```ts
+protected getSensitiveKeys(): string[]
+```
+
+Provide a list of keys in the function options that should be treated as
+sensitive. If users pass raw values for these keys, they will be rejected.
+
+#### Returns
+
+`string`[]
+
+***
+
 ### init()?
 
 ```ts
 optional init(): Promise<void>
 ```
 
 Optionally load any resources needed for the embedding function.
 
 This method is called after the embedding function has been initialized
 but before any embeddings are computed. It is useful for loading local models
 or other resources that are needed for the embedding function to work.
 
 #### Returns
 
 `Promise`<`void`>
@@ -108,6 +146,24 @@ The number of dimensions of the embeddings
 ***
 
+### resolveVariables()
+
+```ts
+protected resolveVariables(config): Partial<M>
+```
+
+Apply variables to the config.
+
+#### Parameters
+
+* **config**: `Partial`<`M`>
+
+#### Returns
+
+`Partial`<`M`>
+
+***
+
 ### sourceField()
 
 ```ts
@@ -134,37 +190,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d
 ### toJSON()
 
 ```ts
-abstract toJSON(): Partial<M>
+toJSON(): Record<string, any>
 ```
 
-Convert the embedding function to a JSON object
-It is used to serialize the embedding function to the schema
-It's important that any object returned by this method contains all the necessary
-information to recreate the embedding function
-
-It should return the same object that was passed to the constructor
-If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
+Get the original arguments to the constructor, to serialize them so they
+can be used to recreate the embedding function later.
 
 #### Returns
 
-`Partial`<`M`>
-
-#### Example
-
-```ts
-class MyEmbeddingFunction extends EmbeddingFunction {
-  constructor(options: {model: string, timeout: number}) {
-    super();
-    this.model = options.model;
-    this.timeout = options.timeout;
-  }
-  toJSON() {
-    return {
-      model: this.model,
-      timeout: this.timeout,
-    };
-  }
-```
+`Record`<`string`, `any`>
 
 ***
````
### EmbeddingFunctionRegistry (TypeScript API reference)

````diff
@@ -80,6 +80,28 @@ getTableMetadata(functions): Map<string, string>
 ***
 
+### getVar()
+
+```ts
+getVar(name): undefined | string
+```
+
+Get a variable.
+
+#### Parameters
+
+* **name**: `string`
+
+#### Returns
+
+`undefined` \| `string`
+
+#### See
+
+[setVar](EmbeddingFunctionRegistry.md#setvar)
+
+***
+
 ### length()
 
 ```ts
@@ -145,3 +167,31 @@ reset the registry to the initial state
 #### Returns
 
 `void`
+
+***
+
+### setVar()
+
+```ts
+setVar(name, value): void
+```
+
+Set a variable. These can be accessed in the embedding function
+configuration using the syntax `$var:variable_name`. If they are not
+set, an error will be thrown letting you know which key is unset. If you
+want to supply a default value, you can add an additional part in the
+configuration like so: `$var:variable_name:default_value`. Default values
+can be used for runtime configurations that are not sensitive, such as
+whether to use a GPU for inference.
+
+The name must not contain colons. The default value can contain colons.
+
+#### Parameters
+
+* **name**: `string`
+
+* **value**: `string`
+
+#### Returns
+
+`void`
````
### TextEmbeddingFunction (TypeScript API reference)

````diff
@@ -114,12 +114,37 @@ abstract generateEmbeddings(texts, ...args): Promise<number[][] | Float32Array[]
 ***
 
+### getSensitiveKeys()
+
+```ts
+protected getSensitiveKeys(): string[]
+```
+
+Provide a list of keys in the function options that should be treated as
+sensitive. If users pass raw values for these keys, they will be rejected.
+
+#### Returns
+
+`string`[]
+
+#### Inherited from
+
+[`EmbeddingFunction`](EmbeddingFunction.md).[`getSensitiveKeys`](EmbeddingFunction.md#getsensitivekeys)
+
+***
+
 ### init()?
 
 ```ts
 optional init(): Promise<void>
 ```
 
 Optionally load any resources needed for the embedding function.
 
 This method is called after the embedding function has been initialized
 but before any embeddings are computed. It is useful for loading local models
 or other resources that are needed for the embedding function to work.
 
 #### Returns
 
 `Promise`<`void`>
@@ -148,6 +173,28 @@ The number of dimensions of the embeddings
 ***
 
+### resolveVariables()
+
+```ts
+protected resolveVariables(config): Partial<M>
+```
+
+Apply variables to the config.
+
+#### Parameters
+
+* **config**: `Partial`<`M`>
+
+#### Returns
+
+`Partial`<`M`>
+
+#### Inherited from
+
+[`EmbeddingFunction`](EmbeddingFunction.md).[`resolveVariables`](EmbeddingFunction.md#resolvevariables)
+
+***
+
 ### sourceField()
 
 ```ts
@@ -173,37 +220,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d
 ### toJSON()
 
 ```ts
-abstract toJSON(): Partial<M>
+toJSON(): Record<string, any>
 ```
 
-Convert the embedding function to a JSON object
-It is used to serialize the embedding function to the schema
-It's important that any object returned by this method contains all the necessary
-information to recreate the embedding function
-
-It should return the same object that was passed to the constructor
-If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
+Get the original arguments to the constructor, to serialize them so they
+can be used to recreate the embedding function later.
 
 #### Returns
 
-`Partial`<`M`>
-
-#### Example
-
-```ts
-class MyEmbeddingFunction extends EmbeddingFunction {
-  constructor(options: {model: string, timeout: number}) {
-    super();
-    this.model = options.model;
-    this.timeout = options.timeout;
-  }
-  toJSON() {
-    return {
-      model: this.model,
-      timeout: this.timeout,
-    };
-  }
-```
+`Record`<`string`, `any`>
 
 #### Inherited from
````
### Polars integration guide

The page's untabbed snippets are replaced with sync/async tabs.

````diff
@@ -9,23 +9,50 @@ LanceDB supports [Polars](https://github.com/pola-rs/polars), a blazingly fast D
 First, we connect to a LanceDB database.
 
+=== "Sync API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-lancedb"
+    --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
+    ```
+
+=== "Async API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-lancedb"
+    --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb_async"
+    ```
-
-```py
---8<-- "python/python/tests/docs/test_python.py:import-lancedb"
---8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
-```
 
 We can load a Polars `DataFrame` to LanceDB directly.
 
-```py
---8<-- "python/python/tests/docs/test_python.py:import-polars"
---8<-- "python/python/tests/docs/test_python.py:create_table_polars"
-```
+=== "Sync API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-polars"
+    --8<-- "python/python/tests/docs/test_python.py:create_table_polars"
+    ```
+
+=== "Async API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:import-polars"
+    --8<-- "python/python/tests/docs/test_python.py:create_table_polars_async"
+    ```
 
 We can now perform similarity search via the LanceDB Python API.
 
-```py
---8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
-```
+=== "Sync API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
+    ```
+
+=== "Async API"
+
+    ```py
+    --8<-- "python/python/tests/docs/test_python.py:vector_search_polars_async"
+    ```
 
 In addition to the selected columns, LanceDB also returns a vector
 and also the `_distance` column which is the distance between the query
@@ -112,4 +139,3 @@ The reason it's beneficial to not convert the LanceDB Table
 to a DataFrame is because the table can potentially be way larger
 than memory, and Polars LazyFrames allow us to work with such
 larger-than-memory datasets by not loading it into memory all at once.
-
````
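The tabbed snippets are pulled from the repository's test suite and are elided in this mirror. A self-contained sketch of the same sync-API flow, with an illustrative database path, table name, and data:

```python
import lancedb
import polars as pl

db = lancedb.connect("data/polars-demo")

# A DataFrame with a list-typed "vector" column can be ingested directly.
df = pl.DataFrame(
    {
        "vector": [[3.1, 4.1], [5.9, 26.5]],
        "item": ["foo", "bar"],
    }
)
table = db.create_table("pl_table", data=df)

# Vector search returns the selected columns plus "vector" and "_distance".
print(table.search([3.0, 4.0]).limit(1).to_polars())

# Converting the whole table yields a LazyFrame, so larger-than-memory
# tables can be scanned without materializing them all at once.
lazy = table.to_polars()
print(lazy.filter(pl.col("item") == "foo").collect())
```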
### Pydantic integration guide

The `pydantic_to_schema()` prose is moved below the type-conversion table and replaced at the top by tested snippet includes.

````diff
@@ -2,14 +2,19 @@
 [Pydantic](https://docs.pydantic.dev/latest/) is a data validation library in Python.
 LanceDB integrates with Pydantic for schema inference, data ingestion, and query result casting.
+Using [LanceModel][lancedb.pydantic.LanceModel], users can seamlessly
+integrate Pydantic with the rest of the LanceDB APIs.
 
 ## Schema
 
-LanceDB supports to create Apache Arrow Schema from a
-[Pydantic BaseModel](https://docs.pydantic.dev/latest/api/main/#pydantic.main.BaseModel)
-via [pydantic_to_schema()](python.md#lancedb.pydantic.pydantic_to_schema) method.
-
-::: lancedb.pydantic.pydantic_to_schema
+```python
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:imports"
+
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:base_model"
+
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:set_url"
+--8<-- "python/python/tests/docs/test_pydantic_integration.py:base_example"
+```
 
 ## Vector Field
@@ -34,3 +39,9 @@ Current supported type conversions:
 | `list` | `pyarrow.List` |
 | `BaseModel` | `pyarrow.Struct` |
 | `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
+
+LanceDB supports to create Apache Arrow Schema from a
+[Pydantic BaseModel][pydantic.BaseModel]
+via [pydantic_to_schema()](python.md#lancedb.pydantic.pydantic_to_schema) method.
+
+::: lancedb.pydantic.pydantic_to_schema
````
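Here too the tested snippets are elided; a short sketch of the `LanceModel` pattern the page describes, with illustrative model and field names:

```python
import lancedb
from lancedb.pydantic import LanceModel, Vector


class Item(LanceModel):
    # Vector(2) maps to pyarrow.FixedSizeList(float32, 2), per the table above.
    vector: Vector(2)
    item: str


db = lancedb.connect("data/pydantic-demo")
table = db.create_table("items", schema=Item)
table.add([Item(vector=[3.1, 4.1], item="foo")])

# Query results can be cast straight back into the Pydantic model.
hit = table.search([3.0, 4.0]).limit(1).to_pydantic(Item)[0]
print(hit.item)
```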
### Documentation doctest configuration

One entry in this list is newly added (+1 line), but the mirror does not preserve which, so the lines are shown unmarked.

```diff
@@ -15,6 +15,7 @@ excluded_globs = [
   "../src/python/duckdb.md",
   "../src/python/pandas_and_pyarrow.md",
   "../src/python/polars_arrow.md",
   "../src/pydantic.md",
   "../src/embeddings/*.md",
   "../src/concepts/*.md",
   "../src/ann_indexes.md",
```
### Java POM files

Two `pom.xml` files (a module's parent reference, and the parent POM itself) bump the version:

```diff
@@ -8,7 +8,7 @@
 <parent>
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.16.1-beta.2</version>
+  <version>0.18.0-beta.0</version>
   <relativePath>../pom.xml</relativePath>
 </parent>
@@ -6,7 +6,7 @@
 <groupId>com.lancedb</groupId>
 <artifactId>lancedb-parent</artifactId>
-<version>0.16.1-beta.2</version>
+<version>0.18.0-beta.0</version>
 <packaging>pom</packaging>
 
 <name>LanceDB Parent</name>
```
124
node/package-lock.json
generated
124
node/package-lock.json
generated
@@ -1,12 +1,12 @@
{
  "name": "vectordb",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "name": "vectordb",
-     "version": "0.16.1-beta.2",
+     "version": "0.18.0-beta.0",
      "cpu": [
        "x64",
        "arm64"
@@ -52,14 +52,14 @@
      "uuid": "^9.0.0"
    },
    "optionalDependencies": {
-     "@lancedb/vectordb-darwin-arm64": "0.16.1-beta.2",
-     "@lancedb/vectordb-darwin-x64": "0.16.1-beta.2",
-     "@lancedb/vectordb-linux-arm64-gnu": "0.16.1-beta.2",
-     "@lancedb/vectordb-linux-arm64-musl": "0.16.1-beta.2",
-     "@lancedb/vectordb-linux-x64-gnu": "0.16.1-beta.2",
-     "@lancedb/vectordb-linux-x64-musl": "0.16.1-beta.2",
-     "@lancedb/vectordb-win32-arm64-msvc": "0.16.1-beta.2",
-     "@lancedb/vectordb-win32-x64-msvc": "0.16.1-beta.2"
+     "@lancedb/vectordb-darwin-arm64": "0.18.0-beta.0",
+     "@lancedb/vectordb-darwin-x64": "0.18.0-beta.0",
+     "@lancedb/vectordb-linux-arm64-gnu": "0.18.0-beta.0",
+     "@lancedb/vectordb-linux-arm64-musl": "0.18.0-beta.0",
+     "@lancedb/vectordb-linux-x64-gnu": "0.18.0-beta.0",
+     "@lancedb/vectordb-linux-x64-musl": "0.18.0-beta.0",
+     "@lancedb/vectordb-win32-arm64-msvc": "0.18.0-beta.0",
+     "@lancedb/vectordb-win32-x64-msvc": "0.18.0-beta.0"
    },
    "peerDependencies": {
      "@apache-arrow/ts": "^14.0.2",
@@ -329,6 +329,110 @@
      "@jridgewell/sourcemap-codec": "^1.4.10"
      }
    },
+   "node_modules/@lancedb/vectordb-darwin-arm64": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.18.0-beta.0.tgz",
+     "integrity": "sha512-dLLgMPllYJOiRfPqkqkmoQu48RIa7K4dOF/qFP8Aex3zqeHE/0sFm3DYjtSFc6SR/6yT8u6Y9iFo2cQp5rCFJA==",
+     "cpu": [
+       "arm64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "darwin"
+     ]
+   },
+   "node_modules/@lancedb/vectordb-darwin-x64": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.18.0-beta.0.tgz",
+     "integrity": "sha512-la0eauU0rzHO5eeVjBt8o/5UW4VzRYAuRA7nqUFLX5T6SWP5+UWjqusVVbWGz3ski+8uEX6VhlaFZP5uIJKGIg==",
+     "cpu": [
+       "x64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "darwin"
+     ]
+   },
+   "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.18.0-beta.0.tgz",
+     "integrity": "sha512-AkXI/lB3yu1Di2G1lhilf89V6qPTppb13aAt+/6gU5/PSfA94y9VXD67D4WyvRbuQghJjDvAavMlWMrJc2NuMw==",
+     "cpu": [
+       "arm64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "linux"
+     ]
+   },
+   "node_modules/@lancedb/vectordb-linux-arm64-musl": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.18.0-beta.0.tgz",
+     "integrity": "sha512-kTVcJ4LA8w/7egY4m0EXOt8c1DeFUquVtyvexO+VzIFeeHfBkkrMI0DkE0CpHmk+gctkG7EY39jzjgLnPvppnw==",
+     "cpu": [
+       "arm64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "linux"
+     ]
+   },
+   "node_modules/@lancedb/vectordb-linux-x64-gnu": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.18.0-beta.0.tgz",
+     "integrity": "sha512-KbtIy5DkaWTsKENm5Q27hjovrR7FRuoHhl0wDJtO/2CUZYlrskjEIfcfkfA2CrEQesBug4s5jgsvNM4Wcp6zoA==",
+     "cpu": [
+       "x64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "linux"
+     ]
+   },
+   "node_modules/@lancedb/vectordb-linux-x64-musl": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.18.0-beta.0.tgz",
+     "integrity": "sha512-SF07gmoGVExcF5v+IE6kBbCbXJSDyTgC7QCt+MDS1NsgoQ9OH7IyH7r6HJu16tKflUOUKlUHnP0hQOPpv1fWpg==",
+     "cpu": [
+       "x64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "linux"
+     ]
+   },
+   "node_modules/@lancedb/vectordb-win32-arm64-msvc": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-arm64-msvc/-/vectordb-win32-arm64-msvc-0.18.0-beta.0.tgz",
+     "integrity": "sha512-YYBuSBGDlxJgSI5gHjDmQo9sl05lAXfzil6QiKfgmUMsBtb2sT+GoUCgG6qzsfe99sWiTf+pMeWDsQgfrj9vNw==",
+     "cpu": [
+       "arm64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "win32"
+     ]
+   },
+   "node_modules/@lancedb/vectordb-win32-x64-msvc": {
+     "version": "0.18.0-beta.0",
+     "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.18.0-beta.0.tgz",
+     "integrity": "sha512-t9TXeUnMU7YbP+/nUJpStm75aWwUydZj2AK+G2XwDtQrQo4Xg7/NETEbBeogmIOHuidNQYia8jEeQCUon5/+Dw==",
+     "cpu": [
+       "x64"
+     ],
+     "license": "Apache-2.0",
+     "optional": true,
+     "os": [
+       "win32"
+     ]
+   },
    "node_modules/@neon-rs/cli": {
      "version": "0.0.160",
      "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",

@@ -1,6 +1,6 @@
{
  "name": "vectordb",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "description": " Serverless, low-latency vector database for AI applications",
  "private": false,
  "main": "dist/index.js",
@@ -92,13 +92,13 @@
    }
  },
  "optionalDependencies": {
-   "@lancedb/vectordb-darwin-x64": "0.16.1-beta.2",
-   "@lancedb/vectordb-darwin-arm64": "0.16.1-beta.2",
-   "@lancedb/vectordb-linux-x64-gnu": "0.16.1-beta.2",
-   "@lancedb/vectordb-linux-arm64-gnu": "0.16.1-beta.2",
-   "@lancedb/vectordb-linux-x64-musl": "0.16.1-beta.2",
-   "@lancedb/vectordb-linux-arm64-musl": "0.16.1-beta.2",
-   "@lancedb/vectordb-win32-x64-msvc": "0.16.1-beta.2",
-   "@lancedb/vectordb-win32-arm64-msvc": "0.16.1-beta.2"
+   "@lancedb/vectordb-darwin-x64": "0.18.0-beta.0",
+   "@lancedb/vectordb-darwin-arm64": "0.18.0-beta.0",
+   "@lancedb/vectordb-linux-x64-gnu": "0.18.0-beta.0",
+   "@lancedb/vectordb-linux-arm64-gnu": "0.18.0-beta.0",
+   "@lancedb/vectordb-linux-x64-musl": "0.18.0-beta.0",
+   "@lancedb/vectordb-linux-arm64-musl": "0.18.0-beta.0",
+   "@lancedb/vectordb-win32-x64-msvc": "0.18.0-beta.0",
+   "@lancedb/vectordb-win32-arm64-msvc": "0.18.0-beta.0"
  }
}

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
-version = "0.16.1-beta.2"
+version = "0.18.0-beta.0"
license.workspace = true
description.workspace = true
repository.workspace = true

@@ -17,6 +17,8 @@ import {
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";

+const testOpenAIInteg = process.env.OPENAI_API_KEY == null ? test.skip : test;
+
describe("embedding functions", () => {
  let tmpDir: tmp.DirResult;
  beforeEach(() => {
@@ -29,9 +31,6 @@ describe("embedding functions", () => {

  it("should be able to create a table with an embedding function", async () => {
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {};
-     }
      ndims() {
        return 3;
      }
@@ -75,9 +74,6 @@ describe("embedding functions", () => {
  it("should be able to append and upsert using embedding function", async () => {
    @register()
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {};
-     }
      ndims() {
        return 3;
      }
@@ -143,9 +139,6 @@ describe("embedding functions", () => {
  it("should be able to create an empty table with an embedding function", async () => {
    @register()
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {};
-     }
      ndims() {
        return 3;
      }
@@ -194,9 +187,6 @@ describe("embedding functions", () => {
  it("should error when appending to a table with an unregistered embedding function", async () => {
    @register("mock")
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {};
-     }
      ndims() {
        return 3;
      }
@@ -241,13 +231,35 @@ describe("embedding functions", () => {
      `Function "mock" not found in registry`,
    );
  });

+ testOpenAIInteg("propagates variables through all methods", async () => {
+   delete process.env.OPENAI_API_KEY;
+   const registry = getRegistry();
+   registry.setVar("openai_api_key", "sk-...");
+   const func = registry.get("openai")?.create({
+     model: "text-embedding-ada-002",
+     apiKey: "$var:openai_api_key",
+   }) as EmbeddingFunction;
+
+   const db = await connect("memory://");
+   const wordsSchema = LanceSchema({
+     text: func.sourceField(new Utf8()),
+     vector: func.vectorField(),
+   });
+   const tbl = await db.createEmptyTable("words", wordsSchema, {
+     mode: "overwrite",
+   });
+   await tbl.add([{ text: "hello world" }, { text: "goodbye world" }]);
+
+   const query = "greetings";
+   const actual = (await tbl.search(query).limit(1).toArray())[0];
+   expect(actual).toHaveProperty("text");
+ });
+
  test.each([new Float16(), new Float32(), new Float64()])(
    "should be able to provide manual embeddings with multiple float datatype",
    async (floatType) => {
      class MockEmbeddingFunction extends EmbeddingFunction<string> {
-       toJSON(): object {
-         return {};
-       }
        ndims() {
          return 3;
        }
@@ -292,10 +304,6 @@ describe("embedding functions", () => {
    async (floatType) => {
      @register("test1")
      class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> {
-       toJSON(): object {
-         return {};
-       }
-
        embeddingDataType(): Float {
          return floatType;
        }
@@ -310,9 +318,6 @@ describe("embedding functions", () => {
      }
      @register("test")
      class MockEmbeddingFunction extends EmbeddingFunction<string> {
-       toJSON(): object {
-         return {};
-       }
        ndims() {
          return 3;
        }

@@ -11,7 +11,11 @@ import * as arrow18 from "apache-arrow-18";
import * as tmp from "tmp";

import { connect } from "../lancedb";
-import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
+import {
+  EmbeddingFunction,
+  FunctionOptions,
+  LanceSchema,
+} from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";

describe.each([arrow15, arrow16, arrow17, arrow18])("LanceSchema", (arrow) => {
@@ -39,11 +43,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
  it("should register a new item to the registry", async () => {
    @register("mock-embedding")
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {
-         someText: "hello",
-       };
-     }
      constructor() {
        super();
      }
@@ -89,11 +88,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
  });
  test("should error if registering with the same name", async () => {
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {
-         someText: "hello",
-       };
-     }
      constructor() {
        super();
      }
@@ -114,13 +108,9 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
  });
  test("schema should contain correct metadata", async () => {
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {
-         someText: "hello",
-       };
-     }
-     constructor() {
+     constructor(args: FunctionOptions = {}) {
        super();
+       this.resolveVariables(args);
      }
      ndims() {
        return 3;
@@ -132,7 +122,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
        return data.map(() => [1, 2, 3]);
      }
    }
-   const func = new MockEmbeddingFunction();
+   const func = new MockEmbeddingFunction({ someText: "hello" });

    const schema = LanceSchema({
      id: new arrow.Int32(),
@@ -155,3 +145,79 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
    expect(schema.metadata).toEqual(expectedMetadata);
  });
});
+
+describe("Registry.setVar", () => {
+  const registry = getRegistry();
+
+  beforeEach(() => {
+    @register("mock-embedding")
+    // biome-ignore lint/correctness/noUnusedVariables :
+    class MockEmbeddingFunction extends EmbeddingFunction<string> {
+      constructor(optionsRaw: FunctionOptions = {}) {
+        super();
+        const options = this.resolveVariables(optionsRaw);
+
+        expect(optionsRaw["someKey"].startsWith("$var:someName")).toBe(true);
+        expect(options["someKey"]).toBe("someValue");
+
+        if (options["secretKey"]) {
+          expect(optionsRaw["secretKey"]).toBe("$var:secretKey");
+          expect(options["secretKey"]).toBe("mySecret");
+        }
+      }
+      async computeSourceEmbeddings(data: string[]) {
+        return data.map(() => [1, 2, 3]);
+      }
+      embeddingDataType() {
+        return new arrow18.Float32() as apiArrow.Float;
+      }
+      protected getSensitiveKeys() {
+        return ["secretKey"];
+      }
+    }
+  });
+  afterEach(() => {
+    registry.reset();
+  });
+
+  it("Should error if the variable is not set", () => {
+    console.log(registry.get("mock-embedding"));
+    expect(() =>
+      registry.get("mock-embedding")!.create({ someKey: "$var:someName" }),
+    ).toThrow('Variable "someName" not found');
+  });
+
+  it("should use default values if not set", () => {
+    registry
+      .get("mock-embedding")!
+      .create({ someKey: "$var:someName:someValue" });
+  });
+
+  it("should set a variable that the embedding function understands", () => {
+    registry.setVar("someName", "someValue");
+    registry.get("mock-embedding")!.create({ someKey: "$var:someName" });
+  });
+
+  it("should reject secrets that aren't passed as variables", () => {
+    registry.setVar("someName", "someValue");
+    expect(() =>
+      registry
+        .get("mock-embedding")!
+        .create({ secretKey: "someValue", someKey: "$var:someName" }),
+    ).toThrow(
+      'The key "secretKey" is sensitive and cannot be set directly. Please use the $var: syntax to set it.',
+    );
+  });
+
+  it("should not serialize secrets", () => {
+    registry.setVar("someName", "someValue");
+    registry.setVar("secretKey", "mySecret");
+    const func = registry
+      .get("mock-embedding")!
+      .create({ secretKey: "$var:secretKey", someKey: "$var:someName" });
+    expect(func.toJSON()).toEqual({
+      secretKey: "$var:secretKey",
+      someKey: "$var:someName",
+    });
+  });
+});

@@ -666,11 +666,11 @@ describe("When creating an index", () => {
  expect(fs.readdirSync(indexDir)).toHaveLength(1);

  for await (const r of tbl.query().where("id > 1").select(["id"])) {
-   expect(r.numRows).toBe(10);
+   expect(r.numRows).toBe(298);
  }
  // should also work with 'filter' alias
  for await (const r of tbl.query().filter("id > 1").select(["id"])) {
-   expect(r.numRows).toBe(10);
+   expect(r.numRows).toBe(298);
  }
});

@@ -1038,9 +1038,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
  test("can search using a string", async () => {
    @register()
    class MockEmbeddingFunction extends EmbeddingFunction<string> {
-     toJSON(): object {
-       return {};
-     }
      ndims() {
        return 1;
      }

@@ -43,12 +43,17 @@ test("custom embedding function", async () => {

  @register("my_embedding")
  class MyEmbeddingFunction extends EmbeddingFunction<string> {
-   toJSON(): object {
-     return {};
+   constructor(optionsRaw = {}) {
+     super();
+     const options = this.resolveVariables(optionsRaw);
+     // Initialize using options
    }
    ndims() {
      return 3;
    }
+   protected getSensitiveKeys(): string[] {
+     return [];
+   }
    embeddingDataType(): Float {
      return new Float32();
    }
@@ -94,3 +99,14 @@ test("custom embedding function", async () => {
    expect(await table2.countRows()).toBe(2);
  });
});

+test("embedding function api_key", async () => {
+  // --8<-- [start:register_secret]
+  const registry = getRegistry();
+  registry.setVar("api_key", "sk-...");
+
+  const func = registry.get("openai")!.create({
+    apiKey: "$var:api_key",
+  });
+  // --8<-- [end:register_secret]
+});

@@ -15,6 +15,7 @@ import {
  newVectorType,
} from "../arrow";
import { sanitizeType } from "../sanitize";
+import { getRegistry } from "./registry";

/**
 * Options for a given embedding function
@@ -32,6 +33,22 @@ export interface EmbeddingFunctionConstructor<

/**
 * An embedding function that automatically creates vector representation for a given column.
+ *
+ * It's important that subclasses pass the **original** options to the super constructor
+ * and then pass those options to `resolveVariables` to resolve any variables before
+ * using them.
+ *
+ * @example
+ * ```ts
+ * class MyEmbeddingFunction extends EmbeddingFunction {
+ *   constructor(optionsRaw: {model: string, timeout: number}) {
+ *     super(optionsRaw);
+ *     const options = this.resolveVariables(optionsRaw);
+ *     this.model = options.model;
+ *     this.timeout = options.timeout;
+ *   }
+ * }
+ * ```
 */
export abstract class EmbeddingFunction<
  // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
@@ -44,33 +61,74 @@
   */
  // biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
  readonly TOptions!: M;
- /**
-  * Convert the embedding function to a JSON object
-  * It is used to serialize the embedding function to the schema
-  * It's important that any object returned by this method contains all the necessary
-  * information to recreate the embedding function
-  *
-  * It should return the same object that was passed to the constructor
-  * If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
-  *
-  * @example
-  * ```ts
-  * class MyEmbeddingFunction extends EmbeddingFunction {
-  *   constructor(options: {model: string, timeout: number}) {
-  *     super();
-  *     this.model = options.model;
-  *     this.timeout = options.timeout;
-  *   }
-  *   toJSON() {
-  *     return {
-  *       model: this.model,
-  *       timeout: this.timeout,
-  *     };
-  *   }
-  * ```
-  */
- abstract toJSON(): Partial<M>;
+ #config: Partial<M>;
+
+ /**
+  * Get the original arguments to the constructor, to serialize them so they
+  * can be used to recreate the embedding function later.
+  */
+ // biome-ignore lint/suspicious/noExplicitAny :
+ toJSON(): Record<string, any> {
+   return JSON.parse(JSON.stringify(this.#config));
+ }
+
+ constructor() {
+   this.#config = {};
+ }
+
+ /**
+  * Provide a list of keys in the function options that should be treated as
+  * sensitive. If users pass raw values for these keys, they will be rejected.
+  */
+ protected getSensitiveKeys(): string[] {
+   return [];
+ }
+
+ /**
+  * Apply variables to the config.
+  */
+ protected resolveVariables(config: Partial<M>): Partial<M> {
+   this.#config = config;
+   const registry = getRegistry();
+   const newConfig = { ...config };
+   for (const [key_, value] of Object.entries(newConfig)) {
+     if (
+       this.getSensitiveKeys().includes(key_) &&
+       !value.startsWith("$var:")
+     ) {
+       throw new Error(
+         `The key "${key_}" is sensitive and cannot be set directly. Please use the $var: syntax to set it.`,
+       );
+     }
+     // Makes TS happy (https://stackoverflow.com/a/78391854)
+     const key = key_ as keyof M;
+     if (typeof value === "string" && value.startsWith("$var:")) {
+       const [name, defaultValue] = value.slice(5).split(":", 2);
+       const variableValue = registry.getVar(name);
+       if (!variableValue) {
+         if (defaultValue) {
+           // biome-ignore lint/suspicious/noExplicitAny:
+           newConfig[key] = defaultValue as any;
+         } else {
+           throw new Error(`Variable "${name}" not found`);
+         }
+       } else {
+         // biome-ignore lint/suspicious/noExplicitAny:
+         newConfig[key] = variableValue as any;
+       }
+     }
+   }
+   return newConfig;
+ }

  /**
   * Optionally load any resources needed for the embedding function.
   *
   * This method is called after the embedding function has been initialized
   * but before any embeddings are computed. It is useful for loading local models
   * or other resources that are needed for the embedding function to work.
   */
  async init?(): Promise<void>;

  /**

@@ -21,11 +21,13 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
  #modelName: OpenAIOptions["model"];

  constructor(
-   options: Partial<OpenAIOptions> = {
+   optionsRaw: Partial<OpenAIOptions> = {
      model: "text-embedding-ada-002",
    },
  ) {
    super();
+   const options = this.resolveVariables(optionsRaw);
+
    const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
    if (!openAIKey) {
      throw new Error("OpenAI API key is required");
@@ -52,10 +54,8 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
    this.#modelName = modelName;
  }

- toJSON() {
-   return {
-     model: this.#modelName,
-   };
+ protected getSensitiveKeys(): string[] {
+   return ["apiKey"];
  }

  ndims(): number {

@@ -23,6 +23,7 @@ export interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
 */
export class EmbeddingFunctionRegistry {
  #functions = new Map<string, EmbeddingFunctionConstructor>();
+ #variables = new Map<string, string>();

  /**
   * Get the number of registered functions
@@ -82,10 +83,7 @@ export class EmbeddingFunctionRegistry {
    };
  } else {
    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
-   create = function (options?: any) {
-     const instance = new factory(options);
-     return instance;
-   };
+   create = (options?: any) => new factory(options);
  }

  return {
@@ -164,6 +162,37 @@ export class EmbeddingFunctionRegistry {

    return metadata;
  }
+
+ /**
+  * Set a variable. These can be accessed in the embedding function
+  * configuration using the syntax `$var:variable_name`. If they are not
+  * set, an error will be thrown letting you know which key is unset. If you
+  * want to supply a default value, you can add an additional part in the
+  * configuration like so: `$var:variable_name:default_value`. Default values
+  * can be used for runtime configurations that are not sensitive, such as
+  * whether to use a GPU for inference.
+  *
+  * The name must not contain colons. The default value can contain colons.
+  *
+  * @param name
+  * @param value
+  */
+ setVar(name: string, value: string): void {
+   if (name.includes(":")) {
+     throw new Error("Variable names cannot contain colons");
+   }
+   this.#variables.set(name, value);
+ }
+
+ /**
+  * Get a variable.
+  * @param name
+  * @returns
+  * @see {@link setVar}
+  */
+ getVar(name: string): string | undefined {
+   return this.#variables.get(name);
+ }
}

const _REGISTRY = new EmbeddingFunctionRegistry();

@@ -44,11 +44,12 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
  #ndims?: number;

  constructor(
-   options: Partial<XenovaTransformerOptions> = {
+   optionsRaw: Partial<XenovaTransformerOptions> = {
      model: "Xenova/all-MiniLM-L6-v2",
    },
  ) {
    super();
+   const options = this.resolveVariables(optionsRaw);

    const modelName = options?.model ?? "Xenova/all-MiniLM-L6-v2";
    this.#tokenizerOptions = {
@@ -59,22 +60,6 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
    this.#ndims = options.ndims;
    this.#modelName = modelName;
  }
- toJSON() {
-   // biome-ignore lint/suspicious/noExplicitAny: <explanation>
-   const obj: Record<string, any> = {
-     model: this.#modelName,
-   };
-   if (this.#ndims) {
-     obj["ndims"] = this.#ndims;
-   }
-   if (this.#tokenizerOptions) {
-     obj["tokenizerOptions"] = this.#tokenizerOptions;
-   }
-   if (this.#tokenizer) {
-     obj["tokenizer"] = this.#tokenizer.name;
-   }
-   return obj;
- }

  async init() {
    let transformers;

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-darwin-arm64",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": ["darwin"],
  "cpu": ["arm64"],
  "main": "lancedb.darwin-arm64.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-darwin-x64",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": ["darwin"],
  "cpu": ["x64"],
  "main": "lancedb.darwin-x64.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-arm64-gnu",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": ["linux"],
  "cpu": ["arm64"],
  "main": "lancedb.linux-arm64-gnu.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-arm64-musl",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": ["linux"],
  "cpu": ["arm64"],
  "main": "lancedb.linux-arm64-musl.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-x64-gnu",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": ["linux"],
  "cpu": ["x64"],
  "main": "lancedb.linux-x64-gnu.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-x64-musl",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": ["linux"],
  "cpu": ["x64"],
  "main": "lancedb.linux-x64-musl.node",

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-win32-arm64-msvc",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": [
    "win32"
  ],

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-win32-x64-msvc",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "os": ["win32"],
  "cpu": ["x64"],
  "main": "lancedb.win32-x64-msvc.node",

nodejs/package-lock.json (generated, 4 lines changed)
@@ -1,12 +1,12 @@
{
  "name": "@lancedb/lancedb",
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "name": "@lancedb/lancedb",
-     "version": "0.16.1-beta.2",
+     "version": "0.18.0-beta.0",
      "cpu": [
        "x64",
        "arm64"

@@ -11,7 +11,7 @@
    "ann"
  ],
  "private": false,
- "version": "0.16.1-beta.2",
+ "version": "0.18.0-beta.0",
  "main": "dist/index.js",
  "exports": {
    ".": "./dist/index.js",

@@ -1,5 +1,5 @@
[tool.bumpversion]
-current_version = "0.19.1-beta.3"
+current_version = "0.21.0-beta.1"
parse = """(?x)
    (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
-version = "0.19.1-beta.3"
+version = "0.21.0-beta.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

@@ -4,8 +4,8 @@ name = "lancedb"
dynamic = ["version"]
dependencies = [
    "deprecation",
    "pylance==0.23.0",
    "tqdm>=4.27.0",
    "pyarrow>=14",
    "pydantic>=1.10",
    "packaging",
    "overrides>=0.7",
@@ -54,8 +54,14 @@ tests = [
    "polars>=0.19, <=1.3.0",
    "tantivy",
    "pyarrow-stubs",
+   "pylance~=0.23.2",
]
-dev = ["ruff", "pre-commit", "pyright", 'typing-extensions>=4.0.0; python_version < "3.11"']
+dev = [
+   "ruff",
+   "pre-commit",
+   "pyright",
+   'typing-extensions>=4.0.0; python_version < "3.11"',
+]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = [

@@ -142,6 +142,10 @@ class CompactionStats:
    files_removed: int
    files_added: int

+class CleanupStats:
+    bytes_removed: int
+    old_versions: int
+
class RemovalStats:
    bytes_removed: int
    old_versions_removed: int

@@ -2,8 +2,10 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from abc import ABC, abstractmethod
+import copy
from typing import List, Union

+from lancedb.util import add_note
import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr
@@ -28,13 +30,67 @@ class EmbeddingFunction(BaseModel, ABC):
        7  # Setting 0 disables retries. Maybe this should not be enabled by default,
    )
    _ndims: int = PrivateAttr()
+   _original_args: dict = PrivateAttr()

    @classmethod
    def create(cls, **kwargs):
        """
        Create an instance of the embedding function
        """
-       return cls(**kwargs)
+       resolved_kwargs = cls.__resolveVariables(kwargs)
+       instance = cls(**resolved_kwargs)
+       instance._original_args = kwargs
+       return instance
+
+   @classmethod
+   def __resolveVariables(cls, args: dict) -> dict:
+       """
+       Resolve variables in the args
+       """
+       from .registry import EmbeddingFunctionRegistry
+
+       new_args = copy.deepcopy(args)
+
+       registry = EmbeddingFunctionRegistry.get_instance()
+       sensitive_keys = cls.sensitive_keys()
+       for k, v in new_args.items():
+           if isinstance(v, str) and not v.startswith("$var:") and k in sensitive_keys:
+               exc = ValueError(
+                   f"Sensitive key '{k}' cannot be set to a hardcoded value"
+               )
+               add_note(exc, "Help: Use $var: to set sensitive keys to variables")
+               raise exc
+
+           if isinstance(v, str) and v.startswith("$var:"):
+               parts = v[5:].split(":", maxsplit=1)
+               if len(parts) == 1:
+                   try:
+                       new_args[k] = registry.get_var(parts[0])
+                   except KeyError:
+                       exc = ValueError(
+                           "Variable '{}' not found in registry".format(parts[0])
+                       )
+                       add_note(
+                           exc,
+                           "Help: Variables are reset in new Python sessions. "
+                           "Use `registry.set_var` to set variables.",
+                       )
+                       raise exc
+               else:
+                   name, default = parts
+                   try:
+                       new_args[k] = registry.get_var(name)
+                   except KeyError:
+                       new_args[k] = default
+       return new_args
+
+   @staticmethod
+   def sensitive_keys() -> List[str]:
+       """
+       Return a list of keys that are sensitive and should not be allowed
+       to be set to hardcoded values in the config. For example, API keys.
+       """
+       return []

    @abstractmethod
    def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
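The `create()` flow above is exercised by the documented registry usage. A minimal, hedged sketch (the key values are placeholders; `get_registry` is the public accessor used by the docs tests later in this diff):

```python
# Hedged sketch of the $var contract implemented above.
from lancedb.embeddings import get_registry

registry = get_registry()
registry.set_var("api_key", "sk-...")  # placeholder secret

# "$var:api_key" is resolved from the registry at create() time; passing a
# raw value for a sensitive key like api_key raises ValueError instead.
func = registry.get("openai").create(api_key="$var:api_key")

# A second colon supplies a default used when the variable is unset.
hf = registry.get("huggingface").create(device="$var:device:cpu")
```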
@@ -103,17 +159,11 @@ class EmbeddingFunction(BaseModel, ABC):
        return texts

    def safe_model_dump(self):
-       from ..pydantic import PYDANTIC_VERSION
-
-       if PYDANTIC_VERSION.major < 2:
-           return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
-       return self.model_dump(
-           exclude={
-               field_name
-               for field_name in self.model_fields
-               if field_name.startswith("_")
-           }
-       )
+       if not hasattr(self, "_original_args"):
+           raise ValueError(
+               "EmbeddingFunction was not created with EmbeddingFunction.create()"
+           )
+       return self._original_args

    @abstractmethod
    def ndims(self) -> int:

@@ -57,6 +57,10 @@ class JinaEmbeddings(EmbeddingFunction):
        # TODO: fix hardcoding
        return 768

+   @staticmethod
+   def sensitive_keys() -> List[str]:
+       return ["api_key"]
+
    def sanitize_input(
        self, inputs: Union[TEXT, IMAGES]
    ) -> Union[List[Any], np.ndarray]:

@@ -54,6 +54,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
    def ndims(self):
        return self._ndims

+   @staticmethod
+   def sensitive_keys():
+       return ["api_key"]
+
    @staticmethod
    def model_names():
        return [

@@ -41,6 +41,7 @@ class EmbeddingFunctionRegistry:

    def __init__(self):
        self._functions = {}
+       self._variables = {}

    def register(self, alias: str = None):
        """
@@ -156,6 +157,28 @@ class EmbeddingFunctionRegistry:
        metadata = json.dumps(json_data, indent=2).encode("utf-8")
        return {"embedding_functions": metadata}

+   def set_var(self, name: str, value: str) -> None:
+       """
+       Set a variable. These can be accessed in embedding configuration using
+       the syntax `$var:variable_name`. If they are not set, an error will be
+       thrown letting you know which variable is missing. If you want to supply
+       a default value, you can add an additional part in the configuration
+       like so: `$var:variable_name:default_value`. Default values can be
+       used for runtime configurations that are not sensitive, such as
+       whether to use a GPU for inference.
+
+       The name must not contain a colon. Default values can contain colons.
+       """
+       if ":" in name:
+           raise ValueError("Variable names cannot contain colons")
+       self._variables[name] = value
+
+   def get_var(self, name: str) -> str:
+       """
+       Get a variable.
+       """
+       return self._variables[name]


# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()
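A hedged note on lifecycle, inferred from the help text above: variables live only in the current process, so a session that reopens a table whose embedding configuration references `$var:` names must set them again first. Sketch (path and table name are illustrative):

```python
# Hedged sketch: re-set per-session variables before using a stored table.
import lancedb
from lancedb.embeddings import get_registry

get_registry().set_var("api_key", "sk-...")  # must run in each new session

db = lancedb.connect("/tmp/lancedb")   # illustrative path
tbl = db.open_table("words")           # illustrative table
hits = tbl.search("greetings").limit(1).to_list()
```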
@@ -40,6 +40,10 @@ class WatsonxEmbeddings(TextEmbeddingFunction):
    url: Optional[str] = None
    params: Optional[Dict] = None

+   @staticmethod
+   def sensitive_keys():
+       return ["api_key"]
+
    @staticmethod
    def model_names():
        return [

@@ -259,7 +259,8 @@ def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:


def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
-   """Convert a Pydantic model to a PyArrow Schema.
+   """Convert a [Pydantic Model][pydantic.BaseModel] to a
+   [PyArrow Schema][pyarrow.Schema].

    Parameters
    ----------
@@ -269,24 +270,25 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
    Returns
    -------
    pyarrow.Schema
        The Arrow Schema

    Examples
    --------

    >>> from typing import List, Optional
    >>> import pydantic
-   >>> from lancedb.pydantic import pydantic_to_schema
+   >>> from lancedb.pydantic import pydantic_to_schema, Vector
    >>> class FooModel(pydantic.BaseModel):
    ...     id: int
    ...     s: str
-   ...     vec: List[float]
+   ...     vec: Vector(1536)  # fixed_size_list<item: float32>[1536]
    ...     li: List[int]
    ...
    >>> schema = pydantic_to_schema(FooModel)
    >>> assert schema == pa.schema([
    ...     pa.field("id", pa.int64(), False),
    ...     pa.field("s", pa.utf8(), False),
-   ...     pa.field("vec", pa.list_(pa.float64()), False),
+   ...     pa.field("vec", pa.list_(pa.float32(), 1536)),
    ...     pa.field("li", pa.list_(pa.int64()), False),
    ...     ])
    """
@@ -308,7 +310,7 @@ class LanceModel(pydantic.BaseModel):
    ...     vector: Vector(2)
    ...
    >>> db = lancedb.connect("./example")
-   >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
+   >>> table = db.create_table("test", schema=TestModel)
    >>> table.add([
    ...     TestModel(name="test", vector=[1.0, 2.0])
    ... ])

@@ -110,7 +110,7 @@ class Query(pydantic.BaseModel):
    full_text_query: Optional[Union[str, dict]] = None

    # top k results to return
-   k: int
+   k: Optional[int] = None

    # # metrics
    metric: str = "L2"
@@ -257,7 +257,7 @@ class LanceQueryBuilder(ABC):

    def __init__(self, table: "Table"):
        self._table = table
-       self._limit = 10
+       self._limit = None
        self._offset = 0
        self._columns = None
        self._where = None
@@ -370,8 +370,7 @@ class LanceQueryBuilder(ABC):
        The maximum number of results to return.
-       The default query limit is 10 results.
-       Entering 0, a negative number, or None will reset
-       the limit to the default value of 10.
+       For ANN/KNN queries, you must specify a limit.
+       For plain searches, all records are returned if the limit is not set.
        *WARNING* if you have a large dataset, setting
        the limit to a large number, e.g. the table size,
        can potentially result in reading a
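A hedged sketch of the new limit semantics described above (path and table name are illustrative):

```python
# Hedged sketch: limit defaults after this change.
import lancedb

db = lancedb.connect("/tmp/lancedb")   # illustrative path
tbl = db.open_table("books")           # illustrative table

# Plain (non-vector) scan: no limit set, so all matching rows are returned.
rows = tbl.search().where("book_id > 1").to_list()

# Vector query: the builder falls back to a limit of 10 unless one is set.
top3 = tbl.search([1.0, 2.0]).limit(3).to_list()
```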
@@ -595,6 +594,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        fast_search: bool = False,
    ):
        super().__init__(table)
+       if self._limit is None:
+           self._limit = 10
        self._query = query
        self._distance_type = "L2"
        self._nprobes = 20
@@ -888,6 +889,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
        fts_columns: Union[str, List[str]] = [],
    ):
        super().__init__(table)
+       if self._limit is None:
+           self._limit = 10
        self._query = query
        self._phrase_query = False
        self.ordering_field_name = ordering_field_name
@@ -1055,7 +1058,7 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
        query = Query(
            columns=self._columns,
            filter=self._where,
-           k=self._limit or 10,
+           k=self._limit,
            with_row_id=self._with_row_id,
            vector=[],
            # not actually respected in remote query

@@ -3,7 +3,9 @@

from __future__ import annotations

+import asyncio
import inspect
+import deprecation
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -23,15 +25,15 @@ from typing import (
)
from urllib.parse import urlparse

-import lance
+from . import __version__
from lancedb.arrow import peek_reader
from lancedb.background_loop import LOOP
-from .dependencies import _check_for_pandas
+from .dependencies import _check_for_hugging_face, _check_for_pandas
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
import numpy as np
from lance import LanceDataset
-from lance.dependencies import _check_for_hugging_face

from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -39,6 +41,8 @@ from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import (
+   AsyncFTSQuery,
+   AsyncHybridQuery,
    AsyncQuery,
    AsyncVectorQuery,
    LanceEmptyQueryBuilder,
@@ -62,10 +66,14 @@ from .index import lang_mapping


if TYPE_CHECKING:
-   from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
+   from ._lancedb import (
+       Table as LanceDBTable,
+       OptimizeStats,
+       CleanupStats,
+       CompactionStats,
+   )
    from .db import LanceDBConnection
    from .index import IndexConfig
-   from lance.dataset import CleanupStats, ReaderLike
    import pandas
    import PIL

@@ -76,10 +84,9 @@ QueryType = Literal["vector", "fts", "hybrid", "auto"]


def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
-   if _check_for_hugging_face(data):
    # Huggingface datasets
-   from lance.dependencies import datasets
+   from lancedb.dependencies import datasets

+   if _check_for_hugging_face(data):
        if isinstance(data, datasets.Dataset):
            schema = data.features.arrow_schema
            return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
@@ -1070,7 +1077,7 @@ class Table(ABC):
        older_than: Optional[timedelta] = None,
        *,
        delete_unverified: bool = False,
-   ) -> CleanupStats:
+   ) -> "CleanupStats":
        """
        Clean up old versions of the table, freeing disk space.

@@ -1381,6 +1388,14 @@ class LanceTable(Table):

    def to_lance(self, **kwargs) -> LanceDataset:
        """Return the LanceDataset backing this table."""
+       try:
+           import lance
+       except ImportError:
+           raise ImportError(
+               "The lance library is required to use this function. "
+               "Please install with `pip install pylance`."
+           )
+
        return lance.dataset(
            self._dataset_path,
            version=self.version,
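For reference, a hedged usage sketch of `to_lance` (names are illustrative):

```python
# Hedged sketch: accessing the Lance dataset that backs a table.
import lancedb

db = lancedb.connect("/tmp/lancedb")   # illustrative path
tbl = db.open_table("books")           # illustrative table

ds = tbl.to_lance()  # a lance.LanceDataset pinned to the table's version
print(ds.version, ds.count_rows())
```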
@@ -1840,7 +1855,7 @@ class LanceTable(Table):

    def merge(
        self,
-       other_table: Union[LanceTable, ReaderLike],
+       other_table: Union[LanceTable, DATA],
        left_on: str,
        right_on: Optional[str] = None,
        schema: Optional[Union[pa.Schema, LanceModel]] = None,
@@ -1890,12 +1905,13 @@ class LanceTable(Table):
    1  2    b    e
    2  3    c    f
    """
        if isinstance(schema, LanceModel):
            schema = schema.to_arrow_schema()
        if isinstance(other_table, LanceTable):
            other_table = other_table.to_lance()
        if isinstance(other_table, LanceDataset):
            other_table = other_table.to_table()
        else:
            other_table = _sanitize_data(
                other_table,
                schema,
            )
        self.to_lance().merge(
            other_table, left_on=left_on, right_on=right_on, schema=schema
        )
@@ -2218,12 +2234,17 @@ class LanceTable(Table):
    ):
        LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))

+   @deprecation.deprecated(
+       deprecated_in="0.21.0",
+       current_version=__version__,
+       details="Use `Table.optimize` instead.",
+   )
    def cleanup_old_versions(
        self,
        older_than: Optional[timedelta] = None,
        *,
        delete_unverified: bool = False,
-   ) -> CleanupStats:
+   ) -> "CleanupStats":
        """
        Clean up old versions of the table, freeing disk space.

@@ -2248,6 +2269,11 @@ class LanceTable(Table):
            older_than, delete_unverified=delete_unverified
        )

+   @deprecation.deprecated(
+       deprecated_in="0.21.0",
+       current_version=__version__,
+       details="Use `Table.optimize` instead.",
+   )
    def compact_files(self, *args, **kwargs) -> CompactionStats:
        """
        Run the compaction process on the table.
@@ -2379,6 +2405,19 @@ class LanceTable(Table):
        """
        LOOP.run(self._table.migrate_v2_manifest_paths())

+   def replace_field_metadata(self, field_name: str, new_metadata: Dict[str, str]):
+       """
+       Replace the metadata of a field in the schema
+
+       Parameters
+       ----------
+       field_name: str
+           The name of the field to replace the metadata for
+       new_metadata: dict
+           The new metadata to set
+       """
+       LOOP.run(self._table.replace_field_metadata(field_name, new_metadata))
+

def _handle_bad_vectors(
    reader: pa.RecordBatchReader,
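A hedged usage sketch of the new `replace_field_metadata` (field name and metadata values are illustrative):

```python
# Hedged sketch: attaching custom metadata to a schema field.
import lancedb

db = lancedb.connect("/tmp/lancedb")   # illustrative path
tbl = db.open_table("books")           # illustrative table

tbl.replace_field_metadata("vector", {"model": "text-embedding-ada-002"})
print(tbl.schema.field("vector").metadata)
```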
@@ -2679,7 +2718,7 @@ class AsyncTable:
        self.close()

    def is_open(self) -> bool:
-       """Return True if the table is closed."""
+       """Return True if the table is open."""
        return self._inner.is_open()

    def close(self):
@@ -2702,6 +2741,19 @@ class AsyncTable:
        """
        return await self._inner.schema()

+   async def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
+       """
+       Get the embedding functions for the table
+
+       Returns
+       -------
+       funcs: Dict[str, EmbeddingFunctionConfig]
+           A mapping of the vector column to the embedding function
+           or empty dict if not configured.
+       """
+       schema = await self.schema()
+       return EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
+
    async def count_rows(self, filter: Optional[str] = None) -> int:
        """
        Count the number of rows in the table.
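A hedged async sketch of the accessor above (connection details are illustrative):

```python
# Hedged sketch: listing configured embedding functions on an async table.
import asyncio
import lancedb

async def main():
    db = await lancedb.connect_async("/tmp/lancedb")  # illustrative path
    tbl = await db.open_table("words")                # illustrative table
    funcs = await tbl.embedding_functions()
    for column, config in funcs.items():
        print(column, "<-", config.source_column)

asyncio.run(main())
```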
@@ -2931,6 +2983,234 @@ class AsyncTable:

        return LanceMergeInsertBuilder(self, on)

+   @overload
+   async def search(
+       self,
+       query: Optional[Union[str]] = None,
+       vector_column_name: Optional[str] = None,
+       query_type: Literal["auto"] = ...,
+       ordering_field_name: Optional[str] = None,
+       fts_columns: Optional[Union[str, List[str]]] = None,
+   ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ...
+
+   @overload
+   async def search(
+       self,
+       query: Optional[Union[str]] = None,
+       vector_column_name: Optional[str] = None,
+       query_type: Literal["hybrid"] = ...,
+       ordering_field_name: Optional[str] = None,
+       fts_columns: Optional[Union[str, List[str]]] = None,
+   ) -> AsyncHybridQuery: ...
+
+   @overload
+   async def search(
+       self,
+       query: Optional[Union[VEC, "PIL.Image.Image", Tuple]] = None,
+       vector_column_name: Optional[str] = None,
+       query_type: Literal["auto"] = ...,
+       ordering_field_name: Optional[str] = None,
+       fts_columns: Optional[Union[str, List[str]]] = None,
+   ) -> AsyncVectorQuery: ...
+
+   @overload
+   async def search(
+       self,
+       query: Optional[str] = None,
+       vector_column_name: Optional[str] = None,
+       query_type: Literal["fts"] = ...,
+       ordering_field_name: Optional[str] = None,
+       fts_columns: Optional[Union[str, List[str]]] = None,
+   ) -> AsyncFTSQuery: ...
+
+   @overload
+   async def search(
+       self,
+       query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
+       vector_column_name: Optional[str] = None,
+       query_type: Literal["vector"] = ...,
+       ordering_field_name: Optional[str] = None,
+       fts_columns: Optional[Union[str, List[str]]] = None,
+   ) -> AsyncVectorQuery: ...
+
+   async def search(
+       self,
+       query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
+       vector_column_name: Optional[str] = None,
+       query_type: QueryType = "auto",
+       ordering_field_name: Optional[str] = None,
+       fts_columns: Optional[Union[str, List[str]]] = None,
+   ) -> AsyncQuery:
+       """Create a search query to find the nearest neighbors
+       of the given query vector. We currently support [vector search][search]
+       and [full-text search][experimental-full-text-search].
+
+       All query options are defined in [AsyncQuery][lancedb.query.AsyncQuery].
+
+       Parameters
+       ----------
+       query: list/np.ndarray/str/PIL.Image.Image, default None
+           The targeted vector to search for.
+
+           - *default None*.
+           Acceptable types are: list, np.ndarray, PIL.Image.Image
+
+           - If None then the select/where/limit clauses are applied to filter
+           the table
+       vector_column_name: str, optional
+           The name of the vector column to search.
+
+           The vector column needs to be a pyarrow fixed size list type
+
+           - If not specified then the vector column is inferred from
+           the table schema
+
+           - If the table has multiple vector columns then the *vector_column_name*
+           needs to be specified. Otherwise, an error is raised.
+       query_type: str
+           *default "auto"*.
+           Acceptable types are: "vector", "fts", "hybrid", or "auto"
+
+           - If "auto" then the query type is inferred from the query;
+
+               - If `query` is a list/np.ndarray then the query type is
+               "vector";
+
+               - If `query` is a PIL.Image.Image then either do vector search,
+               or raise an error if no corresponding embedding function is found.
+
+           - If `query` is a string, then the query type is "vector" if the
+           table has embedding functions else the query type is "fts"
+
+       Returns
+       -------
+       LanceQueryBuilder
+           A query builder object representing the query.
+       """
+
+       def is_embedding(query):
+           return isinstance(query, (list, np.ndarray, pa.Array, pa.ChunkedArray))
+
+       async def get_embedding_func(
+           vector_column_name: Optional[str],
+           query_type: QueryType,
+           query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]],
+       ) -> Tuple[str, EmbeddingFunctionConfig]:
+           schema = await self.schema()
+           vector_column_name = infer_vector_column_name(
+               schema=schema,
+               query_type=query_type,
+               query=query,
+               vector_column_name=vector_column_name,
+           )
+           funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(
+               schema.metadata
+           )
+           func = funcs.get(vector_column_name)
+           if func is None:
+               error = ValueError(
+                   f"Column '{vector_column_name}' has no registered "
+                   "embedding function."
+               )
+               if len(funcs) > 0:
+                   add_note(
+                       error,
+                       "Embedding functions are registered for columns: "
+                       f"{list(funcs.keys())}",
+                   )
+               else:
+                   add_note(
+                       error, "No embedding functions are registered for any columns."
+                   )
+               raise error
+           return vector_column_name, func
+
+       async def make_embedding(embedding, query):
+           if embedding is not None:
+               loop = asyncio.get_running_loop()
+               # This function is likely to block, since it either calls an expensive
+               # function or makes an HTTP request to an embeddings REST API.
+               return (
+                   await loop.run_in_executor(
+                       None,
+                       embedding.function.compute_query_embeddings_with_retry,
+                       query,
+                   )
+               )[0]
+           else:
+               return None
+
+       if query_type == "auto":
+           # Infer the query type.
+           if is_embedding(query):
+               vector_query = query
+               query_type = "vector"
+           elif isinstance(query, str):
+               try:
+                   (
+                       indices,
+                       (vector_column_name, embedding_conf),
+                   ) = await asyncio.gather(
+                       self.list_indices(),
+                       get_embedding_func(vector_column_name, "auto", query),
+                   )
+               except ValueError as e:
+                   if "Column" in str(
+                       e
+                   ) and "has no registered embedding function" in str(e):
+                       # If the column has no registered embedding function,
+                       # then it's an FTS query.
+                       query_type = "fts"
+                   else:
+                       raise e
+               else:
+                   if embedding_conf is not None:
+                       vector_query = await make_embedding(embedding_conf, query)
+                       if any(
+                           i.columns[0] == embedding_conf.source_column
+                           and i.index_type == "FTS"
+                           for i in indices
+                       ):
+                           query_type = "hybrid"
+                       else:
+                           query_type = "vector"
+                   else:
+                       query_type = "fts"
+           else:
+               # it's an image or something else embeddable.
+               query_type = "vector"
+       elif query_type == "vector":
+           if is_embedding(query):
+               vector_query = query
+           else:
+               vector_column_name, embedding_conf = await get_embedding_func(
+                   vector_column_name, query_type, query
+               )
+               vector_query = await make_embedding(embedding_conf, query)
+       elif query_type == "hybrid":
+           if is_embedding(query):
+               raise ValueError("Hybrid search requires a text query")
+           else:
+               vector_column_name, embedding_conf = await get_embedding_func(
+                   vector_column_name, query_type, query
+               )
+               vector_query = await make_embedding(embedding_conf, query)
+
+       if query_type == "vector":
+           builder = self.query().nearest_to(vector_query)
+           if vector_column_name:
+               builder = builder.column(vector_column_name)
+           return builder
+       elif query_type == "fts":
+           return self.query().nearest_to_text(query, columns=fts_columns or [])
+       elif query_type == "hybrid":
+           builder = self.query().nearest_to(vector_query)
+           if vector_column_name:
+               builder = builder.column(vector_column_name)
+           return builder.nearest_to_text(query, columns=fts_columns or [])
+       else:
+           raise ValueError(f"Unknown query type: '{query_type}'")

    def vector_search(
        self,
        query_vector: Union[VEC, Tuple],
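A hedged sketch of the query-type dispatch implemented above (data and names are illustrative; an embedding function is assumed to be configured for the vector cases):

```python
# Hedged sketch: the three query types reachable from AsyncTable.search().
import asyncio
import lancedb

async def main():
    db = await lancedb.connect_async("/tmp/lancedb")  # illustrative path
    tbl = await db.open_table("words")                # illustrative table

    # Vectors always run as "vector" queries.
    hits = await (await tbl.search([0.1, 0.2, 0.3])).limit(5).to_list()

    # Strings run as "fts" with no embedding function, "vector" with one,
    # and "hybrid" when the source column also has an FTS index.
    hits = await (await tbl.search("greetings", query_type="fts")).limit(5).to_list()

asyncio.run(main())
```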
@@ -2950,7 +3230,9 @@ class AsyncTable:
        # The sync remote table calls into this method, so we need to map the
        # query to the async version of the query and run that here. This is only
        # used for that code path right now.
-       async_query = self.query().limit(query.k)
+       async_query = self.query()
+       if query.k is not None:
+           async_query = async_query.limit(query.k)
        if query.offset > 0:
            async_query = async_query.offset(query.offset)
        if query.columns:
@@ -3366,6 +3648,21 @@ class AsyncTable:
        """
        await self._inner.migrate_manifest_paths_v2()

+   async def replace_field_metadata(
+       self, field_name: str, new_metadata: dict[str, str]
+   ):
+       """
+       Replace the metadata of a field in the schema
+
+       Parameters
+       ----------
+       field_name: str
+           The name of the field to replace the metadata for
+       new_metadata: dict
+           The new metadata to set
+       """
+       await self._inner.replace_field_metadata(field_name, new_metadata)


@dataclass
class IndexStatistics:

@@ -75,6 +75,6 @@ async def test_binary_vector_async():
|
||||
|
||||
query = np.random.randint(0, 2, size=256)
|
||||
packed_query = np.packbits(query)
|
||||
await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow()
|
||||
await (await tbl.search(packed_query)).distance_type("hamming").to_arrow()
|
||||
# --8<-- [end:async_binary_vector]
|
||||
await db.drop_table("my_binary_vectors")
|
||||
|
||||
@@ -53,13 +53,13 @@ async def test_binary_vector_async():
|
||||
query = np.random.random(256)
|
||||
|
||||
# Search for the vectors within the range of [0.1, 0.5)
|
||||
await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow()
|
||||
await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow()
|
||||
|
||||
# Search for the vectors with the distance less than 0.5
|
||||
await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow()
|
||||
await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow()
|
||||
|
||||
# Search for the vectors with the distance greater or equal to 0.1
|
||||
await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow()
|
||||
await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow()
|
||||
|
||||
# --8<-- [end:async_distance_range]
|
||||
await db.drop_table("my_table")
|
||||
|
||||
@@ -28,3 +28,49 @@ def test_embeddings_openai():
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
# --8<-- [end:openai_embeddings]


@pytest.mark.slow
@pytest.mark.asyncio
async def test_embeddings_openai_async():
    uri = "memory://"
    # --8<-- [start:async_openai_embeddings]
    db = await lancedb.connect_async(uri)
    func = get_registry().get("openai").create(name="text-embedding-ada-002")

    class Words(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    table = await db.create_table("words", schema=Words, mode="overwrite")
    await table.add([{"text": "hello world"}, {"text": "goodbye world"}])

    query = "greetings"
    actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0]
    print(actual.text)
    # --8<-- [end:async_openai_embeddings]


def test_embeddings_secret():
    # --8<-- [start:register_secret]
    registry = get_registry()
    registry.set_var("api_key", "sk-...")

    func = registry.get("openai").create(api_key="$var:api_key")
    # --8<-- [end:register_secret]

    try:
        import torch
    except ImportError:
        pytest.skip("torch not installed")

    # --8<-- [start:register_device]
    import torch

    registry = get_registry()
    if torch.cuda.is_available():
        registry.set_var("device", "cuda")

    func = registry.get("huggingface").create(device="$var:device:cpu")
    # --8<-- [end:register_device]
    assert func.device == ("cuda" if torch.cuda.is_available() else "cpu")
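
The `$var:name:default` form used above resolves to the registry variable when one has been set and otherwise falls back to the text after the second colon. A standalone sketch of that resolution rule (variable names illustrative; assumes the huggingface embedding extra is installed):

from lancedb.embeddings.registry import get_registry

registry = get_registry()

# No "device" variable registered yet, so the ":cpu" default applies
func = registry.get("huggingface").create(device="$var:device:cpu")
assert func.device == "cpu"

# Once the variable is set, it takes precedence over the default
registry.set_var("device", "cuda")
func = registry.get("huggingface").create(device="$var:device:cpu")
assert func.device == "cuda"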
@@ -72,8 +72,7 @@ async def test_ann_index_async():
# --8<-- [end:create_ann_index_async]
# --8<-- [start:vector_search_async]
await (
    async_tbl.query()
    .nearest_to(np.random.random((32)))
    (await async_tbl.search(np.random.random((32))))
    .limit(2)
    .nprobes(20)
    .refine_factor(10)
@@ -82,18 +81,14 @@ async def test_ann_index_async():
# --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_async_with_filter]
await (
    async_tbl.query()
    .nearest_to(np.random.random((32)))
    (await async_tbl.search(np.random.random((32))))
    .where("item != 'item 1141'")
    .to_pandas()
)
# --8<-- [end:vector_search_async_with_filter]
# --8<-- [start:vector_search_async_with_select]
await (
    async_tbl.query()
    .nearest_to(np.random.random((32)))
    .select(["vector"])
    .to_pandas()
    (await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas()
)
# --8<-- [end:vector_search_async_with_select]

@@ -164,7 +159,7 @@ async def test_scalar_index_async():
    {"book_id": 3, "vector": [5.0, 6]},
]
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
(await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas())
# --8<-- [end:vector_search_with_scalar_index_async]
# --8<-- [start:update_scalar_index_async]
await async_tbl.add([{"vector": [7, 8], "book_id": 4}])

python/python/tests/docs/test_pydantic_integration.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

# --8<-- [start:imports]
import lancedb
from lancedb.pydantic import Vector, LanceModel
# --8<-- [end:imports]


def test_pydantic_model(tmp_path):
    # --8<-- [start:base_model]
    class PersonModel(LanceModel):
        name: str
        age: int
        vector: Vector(2)

    # --8<-- [end:base_model]

    # --8<-- [start:set_url]
    url = "./example"
    # --8<-- [end:set_url]
    url = tmp_path

    # --8<-- [start:base_example]
    db = lancedb.connect(url)
    table = db.create_table("person", schema=PersonModel)
    table.add(
        [
            PersonModel(name="bob", age=1, vector=[1.0, 2.0]),
            PersonModel(name="alice", age=2, vector=[3.0, 4.0]),
        ]
    )
    assert table.count_rows() == 2
    person = table.search([0.0, 0.0]).limit(1).to_pydantic(PersonModel)
    assert person[0].name == "bob"
    # --8<-- [end:base_example]
@@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async():

query_vector = [100, 100]
# Pandas DataFrame
df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas()
df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
print(df)
# --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_with_filter_async]
# Apply the filter via LanceDB
results = (
    await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
)
results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
assert len(results) == 1
assert results["item"].iloc[0] == "foo"

# Apply the filter via Pandas
df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas()
df = results = await (await async_tbl.search([100, 100])).to_pandas()
results = df[df.price < 15]
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
@@ -188,3 +186,26 @@ def test_polars():
# --8<-- [start:print_table_lazyform]
print(ldf.first().collect())
# --8<-- [end:print_table_lazyform]


@pytest.mark.asyncio
async def test_polars_async():
    uri = "data/sample-lancedb"
    db = await lancedb.connect_async(uri)

    # --8<-- [start:create_table_polars_async]
    data = pl.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    table = await db.create_table("pl_table_async", data=data)
    # --8<-- [end:create_table_polars_async]
    # --8<-- [start:vector_search_polars_async]
    query = [3.0, 4.0]
    result = await (await table.search(query)).limit(1).to_polars()
    print(result)
    print(type(result))
    # --8<-- [end:vector_search_polars_async]

@@ -117,12 +117,11 @@ async def test_vector_search_async():
    for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32"))
]
async_tbl = await async_db.create_table("vector_search_async", data=data)
(await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list())
(await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list())
# --8<-- [end:exhaustive_search_async]
# --8<-- [start:exhaustive_search_async_cosine]
(
    await async_tbl.query()
    .nearest_to(np.random.random((1536)))
    await (await async_tbl.search(np.random.random((1536))))
    .distance_type("cosine")
    .limit(10)
    .to_list()
@@ -145,13 +144,13 @@ async def test_vector_search_async():
async_tbl = await async_db.create_table("documents_async", data=data)
# --8<-- [end:create_table_async_with_nested_schema]
# --8<-- [start:search_result_async_as_pyarrow]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow()
await (await async_tbl.search(np.random.randn(1536))).to_arrow()
# --8<-- [end:search_result_async_as_pyarrow]
# --8<-- [start:search_result_async_as_pandas]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas()
await (await async_tbl.search(np.random.randn(1536))).to_pandas()
# --8<-- [end:search_result_async_as_pandas]
# --8<-- [start:search_result_async_as_list]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_list()
await (await async_tbl.search(np.random.randn(1536))).to_list()
# --8<-- [end:search_result_async_as_list]


@@ -219,9 +218,7 @@ async def test_fts_native_async():

# The async API uses our native FTS algorithm
await async_tbl.create_index("text", config=FTS())
await (
    async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list()
)
await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list()
# [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
# ...
# --8<-- [end:basic_fts_async]
@@ -235,18 +232,11 @@ async def test_fts_native_async():
)
# --8<-- [end:fts_config_folding_async]
# --8<-- [start:fts_prefiltering_async]
await (
    async_tbl.query()
    .nearest_to_text("puppy")
    .limit(10)
    .where("text='foo'")
    .to_list()
)
await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list()
# --8<-- [end:fts_prefiltering_async]
# --8<-- [start:fts_postfiltering_async]
await (
    async_tbl.query()
    .nearest_to_text("puppy")
    (await async_tbl.search("puppy"))
    .limit(10)
    .where("text='foo'")
    .postfilter()
@@ -347,14 +337,8 @@ async def test_hybrid_search_async():
# Create an FTS index before the hybrid search
await async_tbl.create_index("text", config=FTS())
text_query = "flower moon"
vector_query = embeddings.compute_query_embeddings(text_query)[0]
# Hybrid search with the default re-ranker
await (
    async_tbl.query()
    .nearest_to(vector_query)
    .nearest_to_text(text_query)
    .to_pandas()
)
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
# --8<-- [end:basic_hybrid_search_async]
# --8<-- [start:hybrid_search_pass_vector_text_async]
vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import List, Union
import os
from typing import List, Optional, Union
from unittest.mock import MagicMock, patch

import lance
@@ -56,7 +57,7 @@ def test_embedding_function(tmp_path):
conf = EmbeddingFunctionConfig(
    source_column="text",
    vector_column="vector",
    function=MockTextEmbeddingFunction(),
    function=MockTextEmbeddingFunction.create(),
)
metadata = registry.get_table_metadata([conf])
table = table.replace_schema_metadata(metadata)
@@ -80,6 +81,57 @@ def test_embedding_function(tmp_path):
assert np.allclose(actual, expected)


def test_embedding_function_variables():
    @register("variable-testing")
    class VariableTestingFunction(TextEmbeddingFunction):
        key1: str
        secret_key: Optional[str] = None

        @staticmethod
        def sensitive_keys():
            return ["secret_key"]

        def ndims(self):
            pass

        def generate_embeddings(self, _texts):
            pass

    registry = EmbeddingFunctionRegistry.get_instance()

    # Should error if variable is not set
    with pytest.raises(ValueError, match="Variable 'test' not found"):
        registry.get("variable-testing").create(
            key1="$var:test",
        )

    # Should use default values if not set
    func = registry.get("variable-testing").create(key1="$var:test:some_value")
    assert func.key1 == "some_value"

    # Should set a variable that the embedding function understands
    registry.set_var("test", "some_value")
    func = registry.get("variable-testing").create(key1="$var:test")
    assert func.key1 == "some_value"

    # Should reject secrets that aren't passed in as variables
    with pytest.raises(
        ValueError,
        match="Sensitive key 'secret_key' cannot be set to a hardcoded value",
    ):
        registry.get("variable-testing").create(
            key1="whatever", secret_key="some_value"
        )

    # Should not serialize secrets.
    registry.set_var("secret", "secret_value")
    func = registry.get("variable-testing").create(
        key1="whatever", secret_key="$var:secret"
    )
    assert func.secret_key == "secret_value"
    assert func.safe_model_dump()["secret_key"] == "$var:secret"
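
The last assertion is the key property: safe_model_dump() keeps the "$var:secret" reference rather than the resolved value, so embedding-function configs serialized into table metadata never contain the raw secret. A small sketch of that consequence, reusing the function created just above:

dumped = func.safe_model_dump()
assert dumped["secret_key"] == "$var:secret"  # reference only
assert "secret_value" not in str(dumped)      # raw secret is never serialized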


def test_embedding_with_bad_results(tmp_path):
    @register("null-embedding")
    class NullEmbeddingFunction(TextEmbeddingFunction):
@@ -91,9 +143,11 @@ def test_embedding_with_bad_results(tmp_path):
) -> list[Union[np.array, None]]:
    # Return None, which is bad if field is non-nullable
    a = [
        np.full(self.ndims(), np.nan)
        if i % 2 == 0
        else np.random.randn(self.ndims())
        (
            np.full(self.ndims(), np.nan)
            if i % 2 == 0
            else np.random.randn(self.ndims())
        )
        for i in range(len(texts))
    ]
    return a
@@ -359,7 +413,7 @@ def test_embedding_function_safe_model_dump(embedding_type):

# Note: Some embedding types might require specific parameters
try:
    model = registry.get(embedding_type).create()
    model = registry.get(embedding_type).create({"max_retries": 1})
except Exception as e:
    pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}")

@@ -392,3 +446,33 @@ def test_retry(mock_sleep):
result = test_function()
assert mock_sleep.call_count == 9
assert result == "result"


@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key not set"
)
def test_openai_propagates_api_key(monkeypatch):
    # Make sure that if we set it as a variable, the API key is propagated
    api_key = os.environ["OPENAI_API_KEY"]
    monkeypatch.delenv("OPENAI_API_KEY")

    uri = "memory://"
    registry = get_registry()
    registry.set_var("open_api_key", api_key)
    func = registry.get("openai").create(
        name="text-embedding-ada-002",
        max_retries=0,
        api_key="$var:open_api_key",
    )

    class Words(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    db = lancedb.connect(uri)
    table = db.create_table("words", schema=Words, mode="overwrite")
    table.add([{"text": "hello world"}, {"text": "goodbye world"}])

    query = "greetings"
    actual = table.search(query).limit(1).to_pydantic(Words)[0]
    assert len(actual.text) > 0

@@ -174,6 +174,10 @@ def test_search_fts(table, use_tantivy):
assert len(results) == 5
assert len(results[0]) == 3  # id, text, _score

# Default limit of 10
results = table.search("puppy").select(["id", "text"]).to_list()
assert len(results) == 10


@pytest.mark.asyncio
async def test_fts_select_async(async_table):

@@ -1,25 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import List, Union
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path

import lancedb
from lancedb.index import IvfPq, FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.db import AsyncConnection
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.index import FTS, IvfPq
import lancedb.pydantic
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import (
    AsyncFTSQuery,
    AsyncHybridQuery,
    AsyncQueryBase,
    AsyncVectorQuery,
    LanceVectorQueryBuilder,
    Query,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output


@pytest.fixture(scope="module")
@@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
tbl = await db.create_table("test", df)
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
assert len(results) == 2


@pytest.mark.asyncio
async def test_query_search_auto(mem_db_async: AsyncConnection):
    nrows = 1000
    data = pa.table(
        {
            "text": [str(i) for i in range(nrows)],
        }
    )

    @register("test2")
    class TestEmbedding(TextEmbeddingFunction):
        def ndims(self):
            return 4

        def generate_embeddings(
            self, texts: Union[List[str], np.ndarray]
        ) -> List[np.array]:
            embeddings = []
            for text in texts:
                vec = np.array([float(text) / 1000] * self.ndims())
                embeddings.append(vec)
            return embeddings

    registry = get_registry()
    func = registry.get("test2").create()

    class TestModel(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    tbl = await mem_db_async.create_table("test", data, schema=TestModel)

    funcs = await tbl.embedding_functions()
    assert len(funcs) == 1

    # No FTS or vector index
    # Search for vector -> vector query
    q = [0.1] * 4
    query = await tbl.search(q)
    assert isinstance(query, AsyncVectorQuery)

    # Search for string -> vector query
    query = await tbl.search("0.1")
    assert isinstance(query, AsyncVectorQuery)

    await tbl.create_index("text", config=FTS())

    query = await tbl.search("0.1")
    assert isinstance(query, AsyncHybridQuery)

    data_with_vecs = await tbl.to_arrow()
    data_with_vecs = data_with_vecs.replace_schema_metadata(None)
    tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
    with pytest.raises(
        Exception,
        match=(
            "Cannot perform full text search unless an INVERTED index has "
            "been created"
        ),
    ):
        query = await (await tbl2.search("0.1")).to_arrow()


@pytest.mark.asyncio
async def test_query_search_specified(mem_db_async: AsyncConnection):
    nrows, ndims = 1000, 16
    data = pa.table(
        {
            "text": [str(i) for i in range(nrows)],
            "vector": pa.FixedSizeListArray.from_arrays(
                pc.random(nrows * ndims).cast(pa.float32()), ndims
            ),
        }
    )
    table = await mem_db_async.create_table("test", data)
    await table.create_index("text", config=FTS())

    # Validate that specifying fts, vector or hybrid gets the right query.
    q = [0.1] * ndims
    query = await table.search(q, query_type="vector")
    assert isinstance(query, AsyncVectorQuery)

    query = await table.search("0.1", query_type="fts")
    assert isinstance(query, AsyncFTSQuery)

    with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
        await table.search("0.1", query_type="foo")

    with pytest.raises(
        ValueError, match="Column 'vector' has no registered embedding function"
    ) as e:
        await table.search("0.1", query_type="vector")

    assert "No embedding functions are registered for any columns" in exception_output(
        e
    )
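
Taken together with test_query_search_auto above, the dispatch rules for the async search entry point can be summarized as a sketch (behavior inferred from these two tests, not an exhaustive specification):

# query_type="auto" (the default) picks the query flavor from the input:
#   vector-like input                        -> AsyncVectorQuery
#   text + embedding function, no FTS index  -> AsyncVectorQuery (text embedded)
#   text + embedding function + FTS index    -> AsyncHybridQuery
# An explicit query_type bypasses the auto-detection entirely:
query = await table.search("0.1", query_type="fts")             # AsyncFTSQuery
query = await table.search([0.1] * ndims, query_type="vector")  # AsyncVectorQuery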
@@ -9,6 +9,7 @@ import json
import threading
from unittest.mock import MagicMock
import uuid
from packaging.version import Version

import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
@@ -277,11 +278,12 @@ def test_table_create_indices():


@contextlib.contextmanager
def query_test_table(query_handler):
def query_test_table(query_handler, *, server_version=Version("0.1.0")):
    def handler(request):
        if request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.send_header("phalanx-version", str(server_version))
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/query/":
@@ -338,6 +340,7 @@ def test_query_sync_empty_query():
"filter": "true",
"vector": [],
"columns": ["id"],
"prefilter": False,
"version": None,
}

@@ -387,11 +390,25 @@ def test_query_sync_maximal():
)


def test_query_sync_multiple_vectors():
    def handler(_body):
        return pa.table({"id": [1]})
@pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")])
def test_query_sync_batch_queries(server_version):
    def handler(body):
        # TODO: we will add the ability to get the server version,
        # so that we can decide how to perform batch queries.
        vectors = body["vector"]
        if server_version >= Version(
            "0.2.0"
        ):  # we can handle batch queries in a single request since 0.2.0
            assert len(vectors) == 2
            res = []
            for i, vector in enumerate(vectors):
                res.append({"id": 1, "query_index": i})
            return pa.Table.from_pylist(res)
        else:
            assert len(vectors) == 3  # matching dim
            return pa.table({"id": [1]})

    with query_test_table(handler) as table:
    with query_test_table(handler, server_version=server_version) as table:
        results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
        assert len(results) == 2
        results.sort(key=lambda x: x["query_index"])
@@ -406,6 +423,7 @@ def test_query_sync_fts():
    "columns": [],
},
"k": 10,
"prefilter": True,
"vector": [],
"version": None,
}
@@ -423,6 +441,7 @@ def test_query_sync_fts():
},
"k": 42,
"vector": [],
"prefilter": True,
"with_row_id": True,
"version": None,
}
@@ -449,6 +468,7 @@ def test_query_sync_hybrid():
},
"k": 42,
"vector": [],
"prefilter": True,
"with_row_id": True,
"version": None,
}

@@ -32,8 +32,8 @@ pytest.importorskip("lancedb.fts")
def get_test_table(tmp_path, use_tantivy):
    db = lancedb.connect(tmp_path)
    # Create a LanceDB table schema with a vector and a text column
    emb = EmbeddingFunctionRegistry.get_instance().get("test")()
    meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")()
    emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
    meta_emb = EmbeddingFunctionRegistry.get_instance().get("test").create()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -405,7 +405,9 @@ def test_answerdotai_reranker(tmp_path, use_tantivy):


@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
    os.environ.get("OPENAI_API_KEY") is None
    or os.environ.get("OPENAI_BASE_URL") is not None,
    reason="OPENAI_API_KEY not set",
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_openai_reranker(tmp_path, use_tantivy):

@@ -887,7 +887,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
    text: str
    vector: Vector(10)

func = MockTextEmbeddingFunction()
func = MockTextEmbeddingFunction.create()
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})

@@ -934,7 +934,7 @@ def test_create_f16_table(mem_db: DBConnection):


def test_add_with_embedding_function(mem_db: DBConnection):
    emb = EmbeddingFunctionRegistry.get_instance().get("test")()
    emb = EmbeddingFunctionRegistry.get_instance().get("test").create()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -1025,13 +1025,13 @@ def test_empty_query(mem_db: DBConnection):

table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)])
df = table.search().select(["id"]).to_pandas()
assert len(df) == 10
assert len(df) == 100
# None is the same as the default
df = table.search().select(["id"]).limit(None).to_pandas()
assert len(df) == 10
assert len(df) == 100
# An invalid limit is the same as None, which is the same as the default
df = table.search().select(["id"]).limit(-1).to_pandas()
assert len(df) == 10
assert len(df) == 100
# A valid limit should work
df = table.search().select(["id"]).limit(42).to_pandas()
assert len(df) == 42
@@ -1128,7 +1128,7 @@ def test_count_rows(mem_db: DBConnection):

def setup_hybrid_search_table(db: DBConnection, embedding_func):
    # Create a LanceDB table schema with a vector and a text column
    emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
    emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func).create()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -1481,3 +1481,12 @@ async def test_optimize_delete_unverified(tmp_db_async: AsyncConnection, tmp_pat
    cleanup_older_than=timedelta(seconds=0), delete_unverified=True
)
assert stats.prune.old_versions_removed == 2


def test_replace_field_metadata(tmp_path):
    db = lancedb.connect(tmp_path)
    table = db.create_table("my_table", data=[{"x": 0}])
    table.replace_field_metadata("x", {"foo": "bar"})
    schema = table.schema
    field = schema[0].metadata
    assert field == {b"foo": b"bar"}

@@ -127,7 +127,7 @@ def test_append_vector_columns():
conf = EmbeddingFunctionConfig(
    source_column="text",
    vector_column="vector",
    function=MockTextEmbeddingFunction(),
    function=MockTextEmbeddingFunction.create(),
)
metadata = registry.get_table_metadata([conf])

@@ -434,7 +434,7 @@ def test_sanitize_data(
conf = EmbeddingFunctionConfig(
    source_column="text",
    vector_column="vector",
    function=MockTextEmbeddingFunction(),
    function=MockTextEmbeddingFunction.create(),
)
metadata = registry.get_table_metadata([conf])
else:

@@ -10,12 +10,13 @@ use lancedb::table::{
    Table as LanceDbTable,
};
use pyo3::{
    exceptions::{PyRuntimeError, PyValueError},
    exceptions::{PyKeyError, PyRuntimeError, PyValueError},
    pyclass, pymethods,
    types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
    Bound, FromPyObject, PyAny, PyRef, PyResult, Python, ToPyObject,
};
use pyo3_async_runtimes::tokio::future_into_py;
use std::collections::HashMap;

use crate::{
    error::PythonErrorExt,
@@ -486,6 +487,37 @@ impl Table {
            Ok(())
        })
    }

    pub fn replace_field_metadata<'a>(
        self_: PyRef<'a, Self>,
        field_name: String,
        metadata: &Bound<'_, PyDict>,
    ) -> PyResult<Bound<'a, PyAny>> {
        let mut new_metadata = HashMap::<String, String>::new();
        for (column_name, value) in metadata.into_iter() {
            let key: String = column_name.extract()?;
            let value: String = value.extract()?;
            new_metadata.insert(key, value);
        }

        let inner = self_.inner_ref()?.clone();
        future_into_py(self_.py(), async move {
            let native_tbl = inner
                .as_native()
                .ok_or_else(|| PyValueError::new_err("This cannot be run on a remote table"))?;
            let schema = native_tbl.manifest().await.infer_error()?.schema;
            let field = schema
                .field(&field_name)
                .ok_or_else(|| PyKeyError::new_err(format!("Field {} not found", field_name)))?;

            native_tbl
                .replace_field_metadata(vec![(field.id as u32, new_metadata)])
                .await
                .infer_error()?;

            Ok(())
        })
    }
}

#[derive(FromPyObject)]

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.16.1-beta.2"
version = "0.18.0-beta.0"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.16.1-beta.2"
version = "0.18.0-beta.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -70,6 +70,7 @@ candle-core = { version = "0.6.0", optional = true }
candle-transformers = { version = "0.6.0", optional = true }
candle-nn = { version = "0.6.0", optional = true }
tokenizers = { version = "0.19.1", optional = true }
semver = { workspace = true }

# For a workaround, see workspace Cargo.toml
crunchy.workspace = true
@@ -87,6 +88,7 @@ aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
datafusion.workspace = true
http-body = "1" # Matching reqwest
rstest = "0.23.0"


[features]

rust/lancedb/src/catalog.rs (new file, 82 lines)
@@ -0,0 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! Catalog implementation for managing databases

pub mod listing;

use std::collections::HashMap;
use std::sync::Arc;

use crate::database::Database;
use crate::error::Result;
use async_trait::async_trait;

/// Request parameters for listing databases
#[derive(Clone, Debug, Default)]
pub struct DatabaseNamesRequest {
    /// Start listing after this name (exclusive)
    pub start_after: Option<String>,
    /// Maximum number of names to return
    pub limit: Option<u32>,
}

/// Request to open an existing database
#[derive(Clone, Debug)]
pub struct OpenDatabaseRequest {
    /// The name of the database to open
    pub name: String,
    /// A map of database-specific options
    ///
    /// Consult the catalog / database implementation to determine which options are available
    pub database_options: HashMap<String, String>,
}

/// Database creation mode
///
/// The default behavior is Create
pub enum CreateDatabaseMode {
    /// Create new database, error if exists
    Create,
    /// Open existing database if present
    ExistOk,
    /// Overwrite existing database
    Overwrite,
}

impl Default for CreateDatabaseMode {
    fn default() -> Self {
        Self::Create
    }
}

/// Request to create a new database
pub struct CreateDatabaseRequest {
    /// The name of the database to create
    pub name: String,
    /// The creation mode
    pub mode: CreateDatabaseMode,
    /// A map of catalog-specific options, consult your catalog implementation to determine what's available
    pub options: HashMap<String, String>,
}

#[async_trait]
pub trait Catalog: Send + Sync + std::fmt::Debug + 'static {
    /// List database names with pagination
    async fn database_names(&self, request: DatabaseNamesRequest) -> Result<Vec<String>>;

    /// Create a new database
    async fn create_database(&self, request: CreateDatabaseRequest) -> Result<Arc<dyn Database>>;

    /// Open existing database
    async fn open_database(&self, request: OpenDatabaseRequest) -> Result<Arc<dyn Database>>;

    /// Rename database
    async fn rename_database(&self, old_name: &str, new_name: &str) -> Result<()>;

    /// Delete database
    async fn drop_database(&self, name: &str) -> Result<()>;

    /// Delete all databases
    async fn drop_all_databases(&self) -> Result<()>;
}
rust/lancedb/src/catalog/listing.rs (new file, 569 lines)
@@ -0,0 +1,569 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! Catalog implementation based on a local file system.

use std::collections::HashMap;
use std::fs::create_dir_all;
use std::path::Path;
use std::sync::Arc;

use super::{
    Catalog, CreateDatabaseMode, CreateDatabaseRequest, DatabaseNamesRequest, OpenDatabaseRequest,
};
use crate::connection::ConnectRequest;
use crate::database::listing::ListingDatabase;
use crate::database::Database;
use crate::error::{CreateDirSnafu, Error, Result};
use async_trait::async_trait;
use lance::io::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry};
use lance_io::local::to_local_path;
use object_store::path::Path as ObjectStorePath;
use snafu::ResultExt;

/// A catalog implementation that works by listing subfolders in a directory
///
/// The listing catalog will be created with a base folder specified by the URI. Every subfolder
/// in this base folder will be considered a database. These will be opened as a
/// [`crate::database::listing::ListingDatabase`]
#[derive(Debug)]
pub struct ListingCatalog {
    object_store: ObjectStore,

    uri: String,

    base_path: ObjectStorePath,

    storage_options: HashMap<String, String>,
}

impl ListingCatalog {
    /// Try to create a local directory to store the lancedb dataset
    pub fn try_create_dir(path: &str) -> core::result::Result<(), std::io::Error> {
        let path = Path::new(path);
        if !path.try_exists()? {
            create_dir_all(path)?;
        }
        Ok(())
    }

    pub fn uri(&self) -> &str {
        &self.uri
    }

    async fn open_path(path: &str) -> Result<Self> {
        let (object_store, base_path) = ObjectStore::from_path(path).unwrap();
        if object_store.is_local() {
            Self::try_create_dir(path).context(CreateDirSnafu { path })?;
        }

        Ok(Self {
            uri: path.to_string(),
            base_path,
            object_store,
            storage_options: HashMap::new(),
        })
    }

    pub async fn connect(request: &ConnectRequest) -> Result<Self> {
        let uri = &request.uri;
        let parse_res = url::Url::parse(uri);

        match parse_res {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await,
            Ok(url) => {
                let plain_uri = url.to_string();

                let registry = Arc::new(ObjectStoreRegistry::default());
                let storage_options = request.storage_options.clone();
                let os_params = ObjectStoreParams {
                    storage_options: Some(storage_options.clone()),
                    ..Default::default()
                };
                let (object_store, base_path) =
                    ObjectStore::from_uri_and_params(registry, &plain_uri, &os_params).await?;
                if object_store.is_local() {
                    Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
                }

                Ok(Self {
                    uri: String::from(url.clone()),
                    base_path,
                    object_store,
                    storage_options,
                })
            }
            Err(_) => Self::open_path(uri).await,
        }
    }

    fn database_path(&self, name: &str) -> ObjectStorePath {
        self.base_path.child(name.replace('\\', "/"))
    }
}

#[async_trait]
impl Catalog for ListingCatalog {
    async fn database_names(&self, request: DatabaseNamesRequest) -> Result<Vec<String>> {
        let mut f = self
            .object_store
            .read_dir(self.base_path.clone())
            .await?
            .iter()
            .map(Path::new)
            .filter_map(|p| p.file_name().and_then(|s| s.to_str().map(String::from)))
            .collect::<Vec<String>>();
        f.sort();

        if let Some(start_after) = request.start_after {
            let index = f
                .iter()
                .position(|name| name.as_str() > start_after.as_str())
                .unwrap_or(f.len());
            f.drain(0..index);
        }
        if let Some(limit) = request.limit {
            f.truncate(limit as usize);
        }
        Ok(f)
    }

    async fn create_database(&self, request: CreateDatabaseRequest) -> Result<Arc<dyn Database>> {
        let db_path = self.database_path(&request.name);
        let db_path_str = to_local_path(&db_path);
        let exists = Path::new(&db_path_str).exists();

        match request.mode {
            CreateDatabaseMode::Create if exists => {
                return Err(Error::DatabaseAlreadyExists { name: request.name })
            }
            CreateDatabaseMode::Create => {
                create_dir_all(db_path.to_string()).unwrap();
            }
            CreateDatabaseMode::ExistOk => {
                if !exists {
                    create_dir_all(db_path.to_string()).unwrap();
                }
            }
            CreateDatabaseMode::Overwrite => {
                if exists {
                    self.drop_database(&request.name).await?;
                }
                create_dir_all(db_path.to_string()).unwrap();
            }
        }

        let db_uri = format!("/{}/{}", self.base_path, request.name);

        let connect_request = ConnectRequest {
            uri: db_uri,
            api_key: None,
            region: None,
            host_override: None,
            #[cfg(feature = "remote")]
            client_config: Default::default(),
            read_consistency_interval: None,
            storage_options: self.storage_options.clone(),
        };

        Ok(Arc::new(
            ListingDatabase::connect_with_options(&connect_request).await?,
        ))
    }

    async fn open_database(&self, request: OpenDatabaseRequest) -> Result<Arc<dyn Database>> {
        let db_path = self.database_path(&request.name);

        let db_path_str = to_local_path(&db_path);
        let exists = Path::new(&db_path_str).exists();
        if !exists {
            return Err(Error::DatabaseNotFound { name: request.name });
        }

        let connect_request = ConnectRequest {
            uri: db_path.to_string(),
            api_key: None,
            region: None,
            host_override: None,
            #[cfg(feature = "remote")]
            client_config: Default::default(),
            read_consistency_interval: None,
            storage_options: self.storage_options.clone(),
        };

        Ok(Arc::new(
            ListingDatabase::connect_with_options(&connect_request).await?,
        ))
    }

    async fn rename_database(&self, _old_name: &str, _new_name: &str) -> Result<()> {
        Err(Error::NotSupported {
            message: "rename_database is not supported in LanceDB OSS yet".to_string(),
        })
    }

    async fn drop_database(&self, name: &str) -> Result<()> {
        let db_path = self.database_path(name);
        self.object_store
            .remove_dir_all(db_path.clone())
            .await
            .map_err(|err| match err {
                lance::Error::NotFound { .. } => Error::DatabaseNotFound {
                    name: name.to_owned(),
                },
                _ => Error::from(err),
            })?;

        Ok(())
    }

    async fn drop_all_databases(&self) -> Result<()> {
        self.object_store
            .remove_dir_all(self.base_path.clone())
            .await?;
        Ok(())
    }
}

#[cfg(all(test, not(windows)))]
mod tests {
    use super::*;

    /// file:/// URIs with drive letters do not work correctly on Windows
    #[cfg(windows)]
    fn path_to_uri(path: PathBuf) -> String {
        path.to_str().unwrap().to_string()
    }

    #[cfg(not(windows))]
    fn path_to_uri(path: PathBuf) -> String {
        Url::from_file_path(path).unwrap().to_string()
    }

    async fn setup_catalog() -> (TempDir, ListingCatalog) {
        let tempdir = tempfile::tempdir().unwrap();
        let catalog_path = tempdir.path().join("catalog");
        std::fs::create_dir_all(&catalog_path).unwrap();

        let uri = path_to_uri(catalog_path);

        let request = ConnectRequest {
            uri: uri.clone(),
            api_key: None,
            region: None,
            host_override: None,
            #[cfg(feature = "remote")]
            client_config: Default::default(),
            storage_options: HashMap::new(),
            read_consistency_interval: None,
        };

        let catalog = ListingCatalog::connect(&request).await.unwrap();

        (tempdir, catalog)
    }

    use crate::database::{CreateTableData, CreateTableRequest, TableNamesRequest};
    use crate::table::TableDefinition;
    use arrow_schema::Field;
    use std::path::PathBuf;
    use std::sync::Arc;
    use tempfile::{tempdir, TempDir};
    use url::Url;

    #[tokio::test]
    async fn test_database_names() {
        let (_tempdir, catalog) = setup_catalog().await;

        let names = catalog
            .database_names(DatabaseNamesRequest::default())
            .await
            .unwrap();
        assert!(names.is_empty());
    }

    #[tokio::test]
    async fn test_create_database() {
        let (_tempdir, catalog) = setup_catalog().await;

        catalog
            .create_database(CreateDatabaseRequest {
                name: "db1".into(),
                mode: CreateDatabaseMode::Create,
                options: HashMap::new(),
            })
            .await
            .unwrap();

        let names = catalog
            .database_names(DatabaseNamesRequest::default())
            .await
            .unwrap();
        assert_eq!(names, vec!["db1"]);
    }

    #[tokio::test]
    async fn test_create_database_exist_ok() {
        let (_tempdir, catalog) = setup_catalog().await;

        let db1 = catalog
            .create_database(CreateDatabaseRequest {
                name: "db_exist_ok".into(),
                mode: CreateDatabaseMode::ExistOk,
                options: HashMap::new(),
            })
            .await
            .unwrap();
        let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::<Field>::default()));
        db1.create_table(CreateTableRequest {
            name: "test_table".parse().unwrap(),
            data: CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
            mode: Default::default(),
            write_options: Default::default(),
        })
        .await
        .unwrap();

        let db2 = catalog
            .create_database(CreateDatabaseRequest {
                name: "db_exist_ok".into(),
                mode: CreateDatabaseMode::ExistOk,
                options: HashMap::new(),
            })
            .await
            .unwrap();

        let tables = db2.table_names(TableNamesRequest::default()).await.unwrap();
        assert_eq!(tables, vec!["test_table".to_string()]);
    }

    #[tokio::test]
    async fn test_create_database_overwrite() {
        let (_tempdir, catalog) = setup_catalog().await;

        let db = catalog
            .create_database(CreateDatabaseRequest {
                name: "db_overwrite".into(),
                mode: CreateDatabaseMode::Create,
                options: HashMap::new(),
            })
            .await
            .unwrap();
        let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::<Field>::default()));
        db.create_table(CreateTableRequest {
            name: "old_table".parse().unwrap(),
            data: CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
            mode: Default::default(),
            write_options: Default::default(),
        })
        .await
        .unwrap();
        let tables = db.table_names(TableNamesRequest::default()).await.unwrap();
        assert!(!tables.is_empty());

        let new_db = catalog
            .create_database(CreateDatabaseRequest {
                name: "db_overwrite".into(),
                mode: CreateDatabaseMode::Overwrite,
                options: HashMap::new(),
            })
            .await
            .unwrap();

        let tables = new_db
            .table_names(TableNamesRequest::default())
            .await
            .unwrap();
        assert!(tables.is_empty());
    }

    #[tokio::test]
    async fn test_create_database_overwrite_non_existing() {
        let (_tempdir, catalog) = setup_catalog().await;

        catalog
            .create_database(CreateDatabaseRequest {
                name: "new_db".into(),
                mode: CreateDatabaseMode::Overwrite,
                options: HashMap::new(),
            })
            .await
            .unwrap();

        let names = catalog
            .database_names(DatabaseNamesRequest::default())
            .await
            .unwrap();
        assert!(names.contains(&"new_db".to_string()));
    }

    #[tokio::test]
    async fn test_open_database() {
        let (_tempdir, catalog) = setup_catalog().await;

        // Test open non-existent
        let result = catalog
            .open_database(OpenDatabaseRequest {
                name: "missing".into(),
                database_options: HashMap::new(),
            })
            .await;
        assert!(matches!(
            result.unwrap_err(),
            Error::DatabaseNotFound { name } if name == "missing"
        ));

        // Create and open
        catalog
            .create_database(CreateDatabaseRequest {
                name: "valid_db".into(),
                mode: CreateDatabaseMode::Create,
                options: HashMap::new(),
            })
            .await
            .unwrap();

        let db = catalog
            .open_database(OpenDatabaseRequest {
                name: "valid_db".into(),
                database_options: HashMap::new(),
            })
            .await
            .unwrap();
        assert_eq!(
            db.table_names(TableNamesRequest::default()).await.unwrap(),
            Vec::<String>::new()
        );
    }

    #[tokio::test]
    async fn test_drop_database() {
        let (_tempdir, catalog) = setup_catalog().await;

        // Create test database
        catalog
            .create_database(CreateDatabaseRequest {
                name: "to_drop".into(),
                mode: CreateDatabaseMode::Create,
                options: HashMap::new(),
            })
            .await
            .unwrap();

        let names = catalog
            .database_names(DatabaseNamesRequest::default())
            .await
            .unwrap();
        assert!(!names.is_empty());

        // Drop database
        catalog.drop_database("to_drop").await.unwrap();

        let names = catalog
            .database_names(DatabaseNamesRequest::default())
            .await
            .unwrap();
        assert!(names.is_empty());
    }

    #[tokio::test]
    async fn test_drop_all_databases() {
        let (_tempdir, catalog) = setup_catalog().await;

        catalog
            .create_database(CreateDatabaseRequest {
                name: "db1".into(),
                mode: CreateDatabaseMode::Create,
                options: HashMap::new(),
            })
            .await
            .unwrap();
        catalog
            .create_database(CreateDatabaseRequest {
                name: "db2".into(),
                mode: CreateDatabaseMode::Create,
                options: HashMap::new(),
            })
            .await
            .unwrap();

        catalog.drop_all_databases().await.unwrap();

        let names = catalog
            .database_names(DatabaseNamesRequest::default())
            .await
            .unwrap();
        assert!(names.is_empty());
    }

    #[tokio::test]
    async fn test_rename_database_unsupported() {
        let (_tempdir, catalog) = setup_catalog().await;
        let result = catalog.rename_database("old", "new").await;
        assert!(matches!(
            result.unwrap_err(),
            Error::NotSupported { message } if message.contains("rename_database")
        ));
    }

    #[tokio::test]
    async fn test_connect_local_path() {
        let tmp_dir = tempdir().unwrap();
        let path = tmp_dir.path().to_str().unwrap();

        let request = ConnectRequest {
            uri: path.to_string(),
            api_key: None,
            region: None,
            host_override: None,
            #[cfg(feature = "remote")]
            client_config: Default::default(),
            storage_options: HashMap::new(),
            read_consistency_interval: None,
        };

        let catalog = ListingCatalog::connect(&request).await.unwrap();
        assert!(catalog.object_store.is_local());
        assert_eq!(catalog.uri, path);
    }

    #[tokio::test]
    async fn test_connect_file_scheme() {
        let tmp_dir = tempdir().unwrap();
        let path = tmp_dir.path();
        let uri = path_to_uri(path.to_path_buf());

        let request = ConnectRequest {
            uri: uri.clone(),
            api_key: None,
            region: None,
            host_override: None,
            #[cfg(feature = "remote")]
            client_config: Default::default(),
            storage_options: HashMap::new(),
            read_consistency_interval: None,
        };

        let catalog = ListingCatalog::connect(&request).await.unwrap();
        assert!(catalog.object_store.is_local());
        assert_eq!(catalog.uri, uri);
    }

    #[tokio::test]
    async fn test_connect_invalid_uri_fallback() {
        let invalid_uri = "invalid:///path";
        let request = ConnectRequest {
            uri: invalid_uri.to_string(),
            api_key: None,
            region: None,
            host_override: None,
            #[cfg(feature = "remote")]
            client_config: Default::default(),
            storage_options: HashMap::new(),
            read_consistency_interval: None,
        };

        let result = ListingCatalog::connect(&request).await;
        assert!(result.is_err());
    }
}
@@ -15,6 +15,10 @@ pub enum Error {
    InvalidInput { message: String },
    #[snafu(display("Table '{name}' was not found"))]
    TableNotFound { name: String },
    #[snafu(display("Database '{name}' was not found"))]
    DatabaseNotFound { name: String },
    #[snafu(display("Database '{name}' already exists."))]
    DatabaseAlreadyExists { name: String },
    #[snafu(display("Index '{name}' was not found"))]
    IndexNotFound { name: String },
    #[snafu(display("Embedding function '{name}' was not found. : {reason}"))]

@@ -191,6 +191,7 @@
//! ```

pub mod arrow;
pub mod catalog;
pub mod connection;
pub mod data;
pub mod database;

@@ -7,6 +7,7 @@ use std::sync::Arc;
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use half::f16;
@@ -464,11 +465,14 @@ impl<T: HasQuery> QueryBase for T {
    }

    fn only_if(mut self, filter: impl AsRef<str>) -> Self {
        self.mut_query().filter = Some(filter.as_ref().to_string());
        self.mut_query().filter = Some(QueryFilter::Sql(filter.as_ref().to_string()));
        self
    }

    fn full_text_search(mut self, query: FullTextSearchQuery) -> Self {
        if self.mut_query().limit.is_none() {
            self.mut_query().limit = Some(DEFAULT_TOP_K);
        }
        self.mut_query().full_text_search = Some(query);
        self
    }
@@ -577,6 +581,17 @@ pub trait ExecutableQuery {
    fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
}

/// A query filter that can be applied to a query
#[derive(Clone, Debug)]
pub enum QueryFilter {
    /// The filter is an SQL string
    Sql(String),
    /// The filter is a Substrait ExtendedExpression message with a single expression
    Substrait(Arc<[u8]>),
    /// The filter is a Datafusion expression
    Datafusion(Expr),
}

/// A basic query into a table without any kind of search
///
/// This will result in a (potentially filtered) scan if executed
@@ -589,7 +604,7 @@ pub struct QueryRequest {
    pub offset: Option<usize>,

    /// Apply filter to the returned rows.
    pub filter: Option<String>,
    pub filter: Option<QueryFilter>,

    /// Perform a full text search on the table.
    pub full_text_search: Option<FullTextSearchQuery>,
@@ -622,7 +637,7 @@ pub struct QueryRequest {
impl Default for QueryRequest {
    fn default() -> Self {
        Self {
            limit: Some(DEFAULT_TOP_K),
            limit: None,
            offset: None,
            filter: None,
            full_text_search: None,
@@ -707,6 +722,11 @@ impl Query {
        let mut vector_query = self.into_vector();
        let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
        vector_query.request.query_vector.push(query_vector);

        if vector_query.request.base.limit.is_none() {
            vector_query.request.base.limit = Some(DEFAULT_TOP_K);
        }

        Ok(vector_query)
    }

@@ -19,12 +19,41 @@ use crate::database::{
|
||||
};
|
||||
use crate::error::Result;
|
||||
use crate::table::BaseTable;
|
||||
use crate::Error;
|
||||
|
||||
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
|
||||
use super::table::RemoteTable;
|
||||
use super::util::batches_to_ipc_bytes;
|
||||
use super::util::{batches_to_ipc_bytes, parse_server_version};
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
|
||||
// the versions of the server that we support
|
||||
// for any new feature that we need to change the SDK behavior, we should bump the server version,
|
||||
// and add a feature flag as method of `ServerVersion` here.
|
||||
pub const DEFAULT_SERVER_VERSION: semver::Version = semver::Version::new(0, 1, 0);
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServerVersion(pub semver::Version);
|
||||
|
||||
impl Default for ServerVersion {
|
||||
fn default() -> Self {
|
||||
Self(DEFAULT_SERVER_VERSION.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerVersion {
|
||||
pub fn parse(version: &str) -> Result<Self> {
|
||||
let version = Self(
|
||||
semver::Version::parse(version).map_err(|e| Error::InvalidInput {
|
||||
message: e.to_string(),
|
||||
})?,
|
||||
);
|
||||
Ok(version)
|
||||
}
|
||||
|
||||
pub fn support_multivector(&self) -> bool {
|
||||
self.0 >= semver::Version::new(0, 2, 0)
|
||||
}
|
||||
}
#[derive(Deserialize)]
struct ListTablesResponse {
    tables: Vec<String>,
@@ -33,7 +62,7 @@ struct ListTablesResponse {
#[derive(Debug)]
pub struct RemoteDatabase<S: HttpSend = Sender> {
    client: RestfulLanceDbClient<S>,
    table_cache: Cache<String, ()>,
    table_cache: Cache<String, Arc<RemoteTable<S>>>,
}

impl RemoteDatabase {
@@ -115,13 +144,19 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
        }
        let (request_id, rsp) = self.client.send(req, true).await?;
        let rsp = self.client.check_response(&request_id, rsp).await?;
        let version = parse_server_version(&request_id, &rsp)?;
        let tables = rsp
            .json::<ListTablesResponse>()
            .await
            .err_to_http(request_id)?
            .tables;
        for table in &tables {
            self.table_cache.insert(table.clone(), ()).await;
            let remote_table = Arc::new(RemoteTable::new(
                self.client.clone(),
                table.clone(),
                version.clone(),
            ));
            self.table_cache.insert(table.clone(), remote_table).await;
        }
        Ok(tables)
    }
@@ -187,34 +222,42 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
            return Err(crate::Error::InvalidInput { message: body });
        }
    }

        self.client.check_response(&request_id, rsp).await?;

        self.table_cache.insert(request.name.clone(), ()).await;

        Ok(Arc::new(RemoteTable::new(
        let rsp = self.client.check_response(&request_id, rsp).await?;
        let version = parse_server_version(&request_id, &rsp)?;
        let table = Arc::new(RemoteTable::new(
            self.client.clone(),
            request.name,
        )))
            request.name.clone(),
            version,
        ));
        self.table_cache
            .insert(request.name.clone(), table.clone())
            .await;

        Ok(table)
    }

    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
        // We describe the table to confirm it exists before moving on.
        if self.table_cache.get(&request.name).await.is_none() {
        if let Some(table) = self.table_cache.get(&request.name).await {
            Ok(table.clone())
        } else {
            let req = self
                .client
                .post(&format!("/v1/table/{}/describe/", request.name));
            let (request_id, resp) = self.client.send(req, true).await?;
            if resp.status() == StatusCode::NOT_FOUND {
            let (request_id, rsp) = self.client.send(req, true).await?;
            if rsp.status() == StatusCode::NOT_FOUND {
                return Err(crate::Error::TableNotFound { name: request.name });
            }
            self.client.check_response(&request_id, resp).await?;
            let rsp = self.client.check_response(&request_id, rsp).await?;
            let version = parse_server_version(&request_id, &rsp)?;
            let table = Arc::new(RemoteTable::new(
                self.client.clone(),
                request.name.clone(),
                version,
            ));
            self.table_cache.insert(request.name, table.clone()).await;
            Ok(table)
        }

        Ok(Arc::new(RemoteTable::new(
            self.client.clone(),
            request.name,
        )))
    }
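
open_table is now a read-through cache: a hit returns the cached handle, a miss describes the table server-side, builds the handle with the advertised server version, and caches it. A minimal, self-contained sketch of the same pattern, assuming nothing beyond std and tokio (the helper names are illustrative, not the crate's API):

    use std::collections::HashMap;
    use tokio::sync::Mutex;

    // Hypothetical read-through helper; `fetch` stands in for the describe round trip.
    async fn get_or_fetch<T: Clone, F>(cache: &Mutex<HashMap<String, T>>, key: &str, fetch: F) -> T
    where
        F: std::future::Future<Output = T>,
    {
        if let Some(hit) = cache.lock().await.get(key).cloned() {
            return hit; // cache hit: no network round trip
        }
        let value = fetch.await; // cache miss: fetch and remember
        cache.lock().await.insert(key.to_string(), value.clone());
        value
    }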
    async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> {
@@ -224,8 +267,10 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
        let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
        let (request_id, resp) = self.client.send(req, false).await?;
        self.client.check_response(&request_id, resp).await?;
        self.table_cache.remove(current_name).await;
        self.table_cache.insert(new_name.into(), ()).await;
        let table = self.table_cache.remove(current_name).await;
        if let Some(table) = table {
            self.table_cache.insert(new_name.into(), table).await;
        }
        Ok(())
    }

@@ -7,10 +7,10 @@ use std::sync::{Arc, Mutex};

use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::{QueryRequest, Select, VectorQueryRequest};
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error};
use crate::{DistanceType, Error, Table};
use arrow_array::RecordBatchReader;
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
@@ -24,7 +24,7 @@ use http::StatusCode;
use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
use lance_datafusion::exec::OneShotExec;
use lance_datafusion::exec::{execute_plan, OneShotExec};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

@@ -41,6 +41,7 @@ use crate::{

use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;

#[derive(Debug)]
@@ -48,15 +49,21 @@ pub struct RemoteTable<S: HttpSend = Sender> {
    #[allow(dead_code)]
    client: RestfulLanceDbClient<S>,
    name: String,
    server_version: ServerVersion,

    version: RwLock<Option<u64>>,
}

impl<S: HttpSend> RemoteTable<S> {
    pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
    pub fn new(
        client: RestfulLanceDbClient<S>,
        name: String,
        server_version: ServerVersion,
    ) -> Self {
        Self {
            client,
            name,
            server_version,
            version: RwLock::new(None),
        }
    }
@@ -149,16 +156,23 @@ impl<S: HttpSend> RemoteTable<S> {
    }

    fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> {
        body["prefilter"] = params.prefilter.into();
        if let Some(offset) = params.offset {
            body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset));
        }

        if let Some(limit) = params.limit {
            body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
        }
        // Server requires k.
        let limit = params.limit.unwrap_or(usize::MAX);
        body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));

        if let Some(filter) = &params.filter {
            body["filter"] = serde_json::Value::String(filter.clone());
            if let QueryFilter::Sql(filter) = filter {
                body["filter"] = serde_json::Value::String(filter.clone());
            } else {
                return Err(Error::NotSupported {
                    message: "querying a remote table with a non-sql filter".to_string(),
                });
            }
        }

        match &params.select {
@@ -205,13 +219,13 @@ impl<S: HttpSend> RemoteTable<S> {
    }
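
For reference, the JSON body these params produce for a plain (non-vector) query has the shape asserted by the tests further down; a sketch with illustrative values:

    // Illustrative request body for a plain (non-vector) query.
    let body = serde_json::json!({
        "prefilter": true,
        "k": 10,
        "filter": "id > 5",
        "vector": [],   // empty vector means no vector search
        "version": null,
    });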

    fn apply_vector_query_params(
        body: &mut serde_json::Value,
        &self,
        mut body: serde_json::Value,
        query: &VectorQueryRequest,
    ) -> Result<()> {
        Self::apply_query_params(body, &query.base)?;
    ) -> Result<Vec<serde_json::Value>> {
        Self::apply_query_params(&mut body, &query.base)?;

        // Apply general parameters, before we dispatch based on number of query vectors.
        body["prefilter"] = query.base.prefilter.into();
        body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
        body["nprobes"] = query.nprobes.into();
        body["lower_bound"] = query.lower_bound.into();
@@ -250,25 +264,40 @@ impl<S: HttpSend> RemoteTable<S> {
            }
        }

        match query.query_vector.len() {
        let bodies = match query.query_vector.len() {
            0 => {
                // Server takes empty vector, not null or undefined.
                body["vector"] = serde_json::Value::Array(Vec::new());
                vec![body]
            }
            1 => {
                body["vector"] = vector_to_json(&query.query_vector[0])?;
                vec![body]
            }
            _ => {
                let vectors = query
                    .query_vector
                    .iter()
                    .map(vector_to_json)
                    .collect::<Result<Vec<_>>>()?;
                body["vector"] = serde_json::Value::Array(vectors);
                if self.server_version.support_multivector() {
                    let vectors = query
                        .query_vector
                        .iter()
                        .map(vector_to_json)
                        .collect::<Result<Vec<_>>>()?;
                    body["vector"] = serde_json::Value::Array(vectors);
                    vec![body]
                } else {
                    // Server does not support multiple vectors in a single query.
                    // We need to send multiple requests.
                    let mut bodies = Vec::with_capacity(query.query_vector.len());
                    for vector in &query.query_vector {
                        let mut body = body.clone();
                        body["vector"] = vector_to_json(vector)?;
                        bodies.push(body);
                    }
                    bodies
                }
            }
        };

        Ok(())
        Ok(bodies)
    }
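
Concretely, a two-vector query against a pre-0.2.0 server clones the single body into two request bodies that differ only in the vector field. A sketch with illustrative values:

    // Two bodies produced from one query with two vectors (illustrative values).
    let bodies = vec![
        serde_json::json!({ "k": 10, "prefilter": true, "vector": [1.0, 2.0, 3.0] }),
        serde_json::json!({ "k": 10, "prefilter": true, "vector": [4.0, 5.0, 6.0] }),
    ];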

    async fn check_mutable(&self) -> Result<()> {
@@ -293,27 +322,34 @@ impl<S: HttpSend> RemoteTable<S> {
        &self,
        query: &AnyQuery,
        _options: QueryExecutionOptions,
    ) -> Result<Pin<Box<dyn RecordBatchStream + Send>>> {
    ) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
        let request = self.client.post(&format!("/v1/table/{}/query/", self.name));

        let version = self.current_version().await;
        let mut body = serde_json::json!({ "version": version });

        match query {
        let requests = match query {
            AnyQuery::Query(query) => {
                Self::apply_query_params(&mut body, query)?;
                // Empty vector can be passed if no vector search is performed.
                body["vector"] = serde_json::Value::Array(Vec::new());
                vec![request.json(&body)]
            }
            AnyQuery::VectorQuery(query) => {
                Self::apply_vector_query_params(&mut body, query)?;
                let bodies = self.apply_vector_query_params(body, query)?;
                bodies
                    .into_iter()
                    .map(|body| request.try_clone().unwrap().json(&body))
                    .collect()
            }
        }
        };

        let request = request.json(&body);
        let (request_id, response) = self.client.send(request, true).await?;
        let stream = self.read_arrow_stream(&request_id, response).await?;
        Ok(stream)
        let futures = requests.into_iter().map(|req| async move {
            let (request_id, response) = self.client.send(req, true).await?;
            self.read_arrow_stream(&request_id, response).await
        });
        let streams = futures::future::try_join_all(futures).await?;
        Ok(streams)
    }
}
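
The fan-out is concurrent: `futures::future::try_join_all` drives all per-vector requests at once, preserves input order in the output, and short-circuits on the first error. A minimal sketch of that behavior, assuming nothing beyond the futures crate:

    use futures::future::try_join_all;

    async fn demo() -> Result<(), String> {
        // Stand-ins for the send + read_arrow_stream round trips.
        let futs = (0..3).map(|i| async move { Ok::<usize, String>(i * 2) });
        let results = try_join_all(futs).await?; // results keep the input order
        assert_eq!(results, vec![0, 2, 4]);
        Ok(())
    }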

@@ -336,7 +372,7 @@ mod test_utils {
    use crate::remote::client::test_utils::MockSender;

    impl RemoteTable<MockSender> {
        pub fn new_mock<F, T>(name: String, handler: F) -> Self
        pub fn new_mock<F, T>(name: String, handler: F, version: Option<semver::Version>) -> Self
        where
            F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
            T: Into<reqwest::Body>,
@@ -345,6 +381,7 @@ mod test_utils {
            Self {
                client,
                name,
                server_version: version.map(ServerVersion).unwrap_or_default(),
                version: RwLock::new(None),
            }
        }
@@ -485,8 +522,17 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
        query: &AnyQuery,
        options: QueryExecutionOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let stream = self.execute_query(query, options).await?;
        Ok(Arc::new(OneShotExec::new(stream)))
        let streams = self.execute_query(query, options).await?;
        if streams.len() == 1 {
            let stream = streams.into_iter().next().unwrap();
            Ok(Arc::new(OneShotExec::new(stream)))
        } else {
            let stream_execs = streams
                .into_iter()
                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
                .collect();
            Table::multi_vector_plan(stream_execs)
        }
    }

    async fn query(
@@ -494,8 +540,24 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
        query: &AnyQuery,
        _options: QueryExecutionOptions,
    ) -> Result<DatasetRecordBatchStream> {
        let stream = self.execute_query(query, _options).await?;
        Ok(DatasetRecordBatchStream::new(stream))
        let streams = self.execute_query(query, _options).await?;

        if streams.len() == 1 {
            Ok(DatasetRecordBatchStream::new(
                streams.into_iter().next().unwrap(),
            ))
        } else {
            let stream_execs = streams
                .into_iter()
                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
                .collect();
            let plan = Table::multi_vector_plan(stream_execs)?;

            Ok(DatasetRecordBatchStream::new(execute_plan(
                plan,
                Default::default(),
            )?))
        }
    }
    async fn update(&self, update: UpdateBuilder) -> Result<u64> {
        self.check_mutable().await?;
@@ -878,8 +940,10 @@ mod tests {
    use futures::{future::BoxFuture, StreamExt, TryFutureExt};
    use lance_index::scalar::FullTextSearchQuery;
    use reqwest::Body;
    use rstest::rstest;

    use crate::index::vector::IvfFlatIndexBuilder;
    use crate::remote::db::DEFAULT_SERVER_VERSION;
    use crate::remote::JSON_CONTENT_TYPE;
    use crate::{
        index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
@@ -1287,6 +1351,52 @@ mod tests {
        table.delete("id in (1, 2, 3)").await.unwrap();
    }

    #[tokio::test]
    async fn test_query_plain() {
        let expected_data = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let expected_data_ref = expected_data.clone();

        let table = Table::new_with_handler("my_table", move |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/v1/table/my_table/query/");
            assert_eq!(
                request.headers().get("Content-Type").unwrap(),
                JSON_CONTENT_TYPE
            );

            let body = request.body().unwrap().as_bytes().unwrap();
            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
            let expected_body = serde_json::json!({
                "k": usize::MAX,
                "prefilter": true,
                "vector": [], // Empty vector means no vector query.
                "version": null,
            });
            assert_eq!(body, expected_body);

            let response_body = write_ipc_file(&expected_data_ref);
            http::Response::builder()
                .status(200)
                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
                .body(response_body)
                .unwrap()
        });

        let data = table
            .query()
            .execute()
            .await
            .unwrap()
            .collect::<Vec<_>>()
            .await;
        assert_eq!(data.len(), 1);
        assert_eq!(data[0].as_ref().unwrap(), &expected_data);
    }

    #[tokio::test]
    async fn test_query_vector_default_values() {
        let expected_data = RecordBatch::try_new(
@@ -1340,6 +1450,55 @@ mod tests {
        assert_eq!(data[0].as_ref().unwrap(), &expected_data);
    }

    #[tokio::test]
    async fn test_query_fts_default_values() {
        let expected_data = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let expected_data_ref = expected_data.clone();

        let table = Table::new_with_handler("my_table", move |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/v1/table/my_table/query/");
            assert_eq!(
                request.headers().get("Content-Type").unwrap(),
                JSON_CONTENT_TYPE
            );

            let body = request.body().unwrap().as_bytes().unwrap();
            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
            let expected_body = serde_json::json!({
                "full_text_query": {
                    "columns": [],
                    "query": "test",
                },
                "prefilter": true,
                "version": null,
                "k": 10,
                "vector": [],
            });
            assert_eq!(body, expected_body);

            let response_body = write_ipc_file(&expected_data_ref);
            http::Response::builder()
                .status(200)
                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
                .body(response_body)
                .unwrap()
        });

        let data = table
            .query()
            .full_text_search(FullTextSearchQuery::new("test".to_owned()))
            .execute()
            .await;
        let data = data.unwrap().collect::<Vec<_>>().await;
        assert_eq!(data.len(), 1);
        assert_eq!(data[0].as_ref().unwrap(), &expected_data);
    }

    #[tokio::test]
    async fn test_query_vector_all_params() {
        let table = Table::new_with_handler("my_table", |request| {
@@ -1422,6 +1581,7 @@ mod tests {
                "k": 10,
                "vector": [],
                "with_row_id": true,
                "prefilter": true,
                "version": null
            });
            assert_eq!(body, expected_body);
@@ -1452,9 +1612,12 @@ mod tests {
            .unwrap();
    }

    #[rstest]
    #[case(DEFAULT_SERVER_VERSION.clone())]
    #[case(semver::Version::new(0, 2, 0))]
    #[tokio::test]
    async fn test_query_multiple_vectors() {
        let table = Table::new_with_handler("my_table", |request| {
    async fn test_batch_queries(#[case] version: semver::Version) {
        let table = Table::new_with_handler_version("my_table", version.clone(), move |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/v1/table/my_table/query/");
            assert_eq!(
@@ -1464,20 +1627,32 @@ mod tests {
            let body: serde_json::Value =
                serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
            let query_vectors = body["vector"].as_array().unwrap();
            assert_eq!(query_vectors.len(), 2);
            assert_eq!(query_vectors[0].as_array().unwrap().len(), 3);
            assert_eq!(query_vectors[1].as_array().unwrap().len(), 3);
            let data = RecordBatch::try_new(
                Arc::new(Schema::new(vec![
                    Field::new("a", DataType::Int32, false),
                    Field::new("query_index", DataType::Int32, false),
                ])),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])),
                    Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])),
                ],
            )
            .unwrap();
            let version = ServerVersion(version.clone());
            let data = if version.support_multivector() {
                assert_eq!(query_vectors.len(), 2);
                assert_eq!(query_vectors[0].as_array().unwrap().len(), 3);
                assert_eq!(query_vectors[1].as_array().unwrap().len(), 3);
                RecordBatch::try_new(
                    Arc::new(Schema::new(vec![
                        Field::new("a", DataType::Int32, false),
                        Field::new("query_index", DataType::Int32, false),
                    ])),
                    vec![
                        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])),
                        Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])),
                    ],
                )
                .unwrap()
            } else {
                // It's a single flat vector, so the length here is the vector dimension.
                assert_eq!(query_vectors.len(), 3);
                RecordBatch::try_new(
                    Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
                    vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
                )
                .unwrap()
            };

            let response_body = write_ipc_file(&data);
            http::Response::builder()
                .status(200)

@@ -4,9 +4,12 @@
use std::io::Cursor;

use arrow_array::RecordBatchReader;
use reqwest::Response;

use crate::Result;

use super::db::ServerVersion;

pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>> {
    const WRITE_BUF_SIZE: usize = 4096;
    let buf = Vec::with_capacity(WRITE_BUF_SIZE);
@@ -22,3 +25,24 @@ pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>>
    }
    Ok(buf.into_inner())
}

pub fn parse_server_version(req_id: &str, rsp: &Response) -> Result<ServerVersion> {
    let version = rsp
        .headers()
        .get("phalanx-version")
        .map(|v| {
            let v = v.to_str().map_err(|e| crate::Error::Http {
                source: e.into(),
                request_id: req_id.to_string(),
                status_code: Some(rsp.status()),
            })?;
            ServerVersion::parse(v).map_err(|e| crate::Error::Http {
                source: e.into(),
                request_id: req_id.to_string(),
                status_code: Some(rsp.status()),
            })
        })
        .transpose()?
        .unwrap_or_default();
    Ok(version)
}
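
So the server advertises its version in the `phalanx-version` response header, and a missing header degrades gracefully to `DEFAULT_SERVER_VERSION`. A sketch of the resulting behavior (values illustrative):

    // No header: defaults to 0.1.0, so multivector stays disabled.
    assert!(!ServerVersion::default().support_multivector());
    // Header advertises 0.2.0: multivector is enabled.
    assert!(ServerVersion::parse("0.2.0").unwrap().support_multivector());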

@@ -62,7 +62,7 @@ use crate::index::{
};
use crate::index::{IndexConfig, IndexStatisticsImpl};
use crate::query::{
    IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery,
    IntoQueryVector, Query, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQuery,
    VectorQueryRequest, DEFAULT_TOP_K,
};
use crate::utils::{
@@ -509,6 +509,27 @@ mod test_utils {
        let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
            name.into(),
            handler,
            None,
        ));
        Self {
            inner,
            // Registry is unused.
            embedding_registry: Arc::new(MemoryRegistry::new()),
        }
    }

    pub fn new_with_handler_version<T>(
        name: impl Into<String>,
        version: semver::Version,
        handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
    ) -> Self
    where
        T: Into<reqwest::Body>,
    {
        let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
            name.into(),
            handler,
            Some(version),
        ));
        Self {
            inner,
@@ -2125,7 +2146,17 @@ impl BaseTable for NativeTable {
    }

    if let Some(filter) = &query.base.filter {
        scanner.filter(filter)?;
        match filter {
            QueryFilter::Sql(sql) => {
                scanner.filter(sql)?;
            }
            QueryFilter::Substrait(substrait) => {
                scanner.filter_substrait(substrait)?;
            }
            QueryFilter::Datafusion(expr) => {
                scanner.filter_expr(expr.clone());
            }
        }
    }
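
The match arms above imply the shape of the new `QueryFilter` type. Its definition is not shown in this diff, but it presumably looks roughly like the following sketch, with variant payloads inferred from the calls above:

    // Inferred sketch only, not the diff's actual definition.
    pub enum QueryFilter {
        Sql(String),                       // filter passed through as a SQL string
        Substrait(Vec<u8>),                // encoded Substrait expression
        Datafusion(datafusion_expr::Expr), // native DataFusion expression
    }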

    if let Some(fts) = &query.base.full_text_search {

@@ -4,6 +4,7 @@
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
use std::{collections::HashMap, sync::Arc};

use arrow_array::RecordBatch;
use arrow_schema::Schema as ArrowSchema;
use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
@@ -17,7 +18,7 @@ use futures::{TryFutureExt, TryStreamExt};

use super::{AnyQuery, BaseTable};
use crate::{
    query::{QueryExecutionOptions, QueryRequest, Select},
    query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
    Result,
};

@@ -104,7 +105,9 @@ impl ExecutionPlan for MetadataEraserExec {
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let stream = self.input.execute(partition, context)?;
        let schema = self.schema.clone();
        let stream = stream.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
        let stream = stream.map_ok(move |batch| {
            RecordBatch::try_new(schema.clone(), batch.columns().to_vec()).unwrap()
        });
        Ok(
            Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
                as SendableRecordBatchStream,
@@ -161,7 +164,13 @@ impl TableProvider for BaseTableAdapter {
            .collect();
        query.select = Select::Columns(field_names);
    }
    assert!(filters.is_empty());
    if !filters.is_empty() {
        let first = filters.first().unwrap().clone();
        let filter = filters[1..]
            .iter()
            .fold(first, |acc, expr| acc.and(expr.clone()));
        query.filter = Some(QueryFilter::Datafusion(filter));
    }
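
The fold combines all pushed-down predicates into a single conjunction before handing it to the query. A small illustration using datafusion_expr, with illustrative column names:

    use datafusion_expr::{col, lit, Expr};

    // [a > 1, b < 2, c = 3]  =>  a > 1 AND b < 2 AND c = 3
    let filters: Vec<Expr> = vec![
        col("a").gt(lit(1)),
        col("b").lt(lit(2)),
        col("c").eq(lit(3)),
    ];
    let first = filters.first().unwrap().clone();
    let combined = filters[1..]
        .iter()
        .fold(first, |acc, expr| acc.and(expr.clone()));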
        if let Some(limit) = limit {
            query.limit = Some(limit);
        } else {
@@ -180,11 +189,7 @@ impl TableProvider for BaseTableAdapter {
        &self,
        filters: &[&Expr],
    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
        // TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan
        Ok(vec![
            TableProviderFilterPushDown::Unsupported;
            filters.len()
        ])
        Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
    }
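
Note that answering `Exact` tells DataFusion the provider applies each pushed-down filter completely, so DataFusion can drop its own FilterExec for those predicates; the scalar-index pushdown test below relies on exactly that behavior.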

    fn statistics(&self) -> Option<Statistics> {
@@ -197,67 +202,257 @@ impl TableProvider for BaseTableAdapter {
pub mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
    use arrow::array::AsArray;
    use arrow_array::{
        BinaryArray, Float64Array, Int32Array, Int64Array, RecordBatch, RecordBatchIterator,
        RecordBatchReader, StringArray, UInt32Array,
    };
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::{datasource::provider_as_source, prelude::SessionContext};
    use datafusion_catalog::TableProvider;
    use datafusion_expr::LogicalPlanBuilder;
    use datafusion_execution::SendableRecordBatchStream;
    use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
    use futures::TryStreamExt;
    use tempfile::tempdir;

    use crate::{connect, table::datafusion::BaseTableAdapter};
    use crate::{
        connect,
        index::{scalar::BTreeIndexBuilder, Index},
        table::datafusion::BaseTableAdapter,
    };

    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
        let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]);
        let schema = Arc::new(
            Schema::new(vec![Field::new("i", DataType::Int32, false)]).with_metadata(metadata),
            Schema::new(vec![
                Field::new("i", DataType::Int32, false),
                Field::new("indexed", DataType::UInt32, false),
            ])
            .with_metadata(metadata),
        );
        RecordBatchIterator::new(
            vec![RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from_iter_values(0..10))],
                vec![
                    Arc::new(Int32Array::from_iter_values(0..10)),
                    Arc::new(UInt32Array::from_iter_values(0..10)),
                ],
            )],
            schema,
        )
    }

    fn make_tbl_two_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
        let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]);
        let schema = Arc::new(
            Schema::new(vec![
                Field::new("ints", DataType::Int64, true),
                Field::new("strings", DataType::Utf8, true),
                Field::new("floats", DataType::Float64, true),
                Field::new("jsons", DataType::Utf8, true),
                Field::new("bins", DataType::Binary, true),
                Field::new("nodates", DataType::Utf8, true),
            ])
            .with_metadata(metadata),
        );
        RecordBatchIterator::new(
            vec![RecordBatch::try_new(
                schema.clone(),
                vec![
                    Arc::new(Int64Array::from_iter_values(0..1000)),
                    Arc::new(StringArray::from_iter_values(
                        (0..1000).map(|i| i.to_string()),
                    )),
                    Arc::new(Float64Array::from_iter_values((0..1000).map(|i| i as f64))),
                    Arc::new(StringArray::from_iter_values(
                        (0..1000).map(|i| format!("{{\"i\":{}}}", i)),
                    )),
                    Arc::new(BinaryArray::from_iter_values(
                        (0..1000).map(|i| (i as u32).to_be_bytes().to_vec()),
                    )),
                    Arc::new(StringArray::from_iter_values(
                        (0..1000).map(|i| i.to_string()),
                    )),
                ],
            )],
            schema,
        )
    }

    struct TestFixture {
        _tmp_dir: tempfile::TempDir,
        // An adapter for a table with make_test_batches batches
        adapter: Arc<BaseTableAdapter>,
        // An adapter for a table with make_tbl_two_test_batches batches
        adapter2: Arc<BaseTableAdapter>,
    }

    impl TestFixture {
        async fn new() -> Self {
            let tmp_dir = tempdir().unwrap();
            let dataset_path = tmp_dir.path().join("test.lance");
            let uri = dataset_path.to_str().unwrap();

            let db = connect(uri).execute().await.unwrap();

            let tbl = db
                .create_table("foo", make_test_batches())
                .execute()
                .await
                .unwrap();

            tbl.create_index(&["indexed"], Index::BTree(BTreeIndexBuilder::default()))
                .execute()
                .await
                .unwrap();

            let tbl2 = db
                .create_table("tbl2", make_tbl_two_test_batches())
                .execute()
                .await
                .unwrap();

            let adapter = Arc::new(
                BaseTableAdapter::try_new(tbl.base_table().clone())
                    .await
                    .unwrap(),
            );

            let adapter2 = Arc::new(
                BaseTableAdapter::try_new(tbl2.base_table().clone())
                    .await
                    .unwrap(),
            );

            Self {
                _tmp_dir: tmp_dir,
                adapter,
                adapter2,
            }
        }

        async fn plan_to_stream(plan: LogicalPlan) -> SendableRecordBatchStream {
            SessionContext::new()
                .execute_logical_plan(plan)
                .await
                .unwrap()
                .execute_stream()
                .await
                .unwrap()
        }

        async fn plan_to_explain(plan: LogicalPlan) -> String {
            let mut explain_stream = SessionContext::new()
                .execute_logical_plan(plan)
                .await
                .unwrap()
                .explain(true, false)
                .unwrap()
                .execute_stream()
                .await
                .unwrap();
            let batch = explain_stream.try_next().await.unwrap().unwrap();
            assert!(explain_stream.try_next().await.unwrap().is_none());

            let plan_descs = batch.columns()[0].as_string::<i32>();
            let plans = batch.columns()[1].as_string::<i32>();

            for (desc, plan) in plan_descs.iter().zip(plans.iter()) {
                if desc.unwrap() == "physical_plan" {
                    return plan.unwrap().to_string();
                }
            }
            panic!("No physical plan found in explain output");
        }

        async fn check_plan(plan: LogicalPlan, expected: &str) {
            let physical_plan = Self::plan_to_explain(plan).await;
            let mut lines_checked = 0;
            for (actual_line, expected_line) in physical_plan.lines().zip(expected.lines()) {
                lines_checked += 1;
                let actual_trimmed = actual_line.trim();
                let expected_trimmed = if let Some(ellipsis_pos) = expected_line.find("...") {
                    expected_line[0..ellipsis_pos].trim()
                } else {
                    expected_line.trim()
                };
                assert_eq!(&actual_trimmed[..expected_trimmed.len()], expected_trimmed);
            }
            assert_eq!(lines_checked, expected.lines().count());
        }
    }

    #[tokio::test]
    async fn test_metadata_erased() {
        let tmp_dir = tempdir().unwrap();
        let dataset_path = tmp_dir.path().join("test.lance");
        let uri = dataset_path.to_str().unwrap();
        let fixture = TestFixture::new().await;

        let db = connect(uri).execute().await.unwrap();
        assert!(fixture.adapter.schema().metadata().is_empty());

        let tbl = db
            .create_table("foo", make_test_batches())
            .execute()
            .await
            .unwrap();

        let provider = Arc::new(
            BaseTableAdapter::try_new(tbl.base_table().clone())
                .await
                .unwrap(),
        );

        assert!(provider.schema().metadata().is_empty());

        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(provider), None)
        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None)
            .unwrap()
            .build()
            .unwrap();

        let mut stream = SessionContext::new()
            .execute_logical_plan(plan)
            .await
            .unwrap()
            .execute_stream()
            .await
            .unwrap();
        let mut stream = TestFixture::plan_to_stream(plan).await;

        while let Some(batch) = stream.try_next().await.unwrap() {
            assert!(batch.schema().metadata().is_empty());
        }
    }

    #[tokio::test]
    async fn test_metadata_erased_with_filter() {
        // Regression test: the metadata eraser was not properly erasing metadata when
        // a filter was applied on top of the scan.
        let fixture = TestFixture::new().await;

        assert!(fixture.adapter.schema().metadata().is_empty());

        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter2), None)
            .unwrap()
            .filter(col("ints").lt(lit(10)))
            .unwrap()
            .build()
            .unwrap();

        let mut stream = TestFixture::plan_to_stream(plan).await;

        while let Some(batch) = stream.try_next().await.unwrap() {
            assert!(batch.schema().metadata().is_empty());
        }
    }

    #[tokio::test]
    async fn test_filter_pushdown() {
        let fixture = TestFixture::new().await;

        // Basic filter; pushing it down is not much different from running it in DataFusion.
        let plan =
            LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter.clone()), None)
                .unwrap()
                .filter(col("i").gt_eq(lit(5)))
                .unwrap()
                .build()
                .unwrap();

        TestFixture::check_plan(
            plan,
            "MetadataEraserExec
  RepartitionExec:...
    CoalesceBatchesExec:...
      FilterExec: i@0 >= 5
        ProjectionExec:...
          LanceScan:...",
        )
        .await;

        // Filter utilizing a scalar index; make sure it gets pushed down.
        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None)
            .unwrap()
            .filter(col("indexed").eq(lit(5)))
            .unwrap()
            .build()
            .unwrap();

        TestFixture::check_plan(plan, "").await;
    }
}