Mirror of https://github.com/lancedb/lancedb.git, synced 2025-12-23 05:19:58 +00:00.
Compare commits: 25 commits (python-v0. ... python-v0.)
| SHA1 |
|---|
| d11819c90c |
| 9b902272f1 |
| 8c0622fa2c |
| 2191f948c3 |
| acc3b03004 |
| 7f091b8c8e |
| c19bdd9a24 |
| dad0ff5cd2 |
| a705621067 |
| 39614fdb7d |
| 96d534d4bc |
| 5051d30d09 |
| db853c4041 |
| 76d1d22bdc |
| d8746c61c6 |
| 1a66df2627 |
| 44670076c1 |
| 92f0b16e46 |
| 1620ba3508 |
| 3ae90dde80 |
| 4f07fea6df |
| 3d7d82cf86 |
| edc4e40a7b |
| ca3806a02f |
| 35cff12e31 |
```diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.19.0-beta.7"
+current_version = "0.19.0-beta.10"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
```
.github/workflows/docs.yml (vendored, 13 changes)

```diff
@@ -18,17 +18,24 @@ concurrency:
   group: "pages"
   cancel-in-progress: true
 
+env:
+  # This reduces the disk space needed for the build
+  RUSTFLAGS: "-C debuginfo=0"
+  # according to: https://matklad.github.io/2021/09/04/fast-rust-builds.html
+  # CI builds are faster with incremental disabled.
+  CARGO_INCREMENTAL: "0"
+
 jobs:
   # Single deploy job since we're just deploying
   build:
     environment:
       name: github-pages
       url: ${{ steps.deployment.outputs.page_url }}
-    runs-on: buildjet-8vcpu-ubuntu-2204
+    runs-on: ubuntu-24.04
     steps:
       - name: Checkout
         uses: actions/checkout@v4
-      - name: Install dependecies needed for ubuntu
+      - name: Install dependencies needed for ubuntu
         run: |
           sudo apt install -y protobuf-compiler libssl-dev
           rustup update && rustup default
@@ -38,6 +45,7 @@ jobs:
           python-version: "3.10"
           cache: "pip"
           cache-dependency-path: "docs/requirements.txt"
+      - uses: Swatinem/rust-cache@v2
       - name: Build Python
         working-directory: python
         run: |
@@ -49,7 +57,6 @@ jobs:
           node-version: 20
           cache: 'npm'
           cache-dependency-path: node/package-lock.json
-      - uses: Swatinem/rust-cache@v2
       - name: Install node dependencies
         working-directory: node
         run: |
```
.github/workflows/python.yml (vendored, 4 changes)

```diff
@@ -136,9 +136,9 @@ jobs:
       - uses: ./.github/workflows/run_tests
         with:
           integration: true
-      - name: Test without pylance
+      - name: Test without pylance or pandas
         run: |
-          pip uninstall -y pylance
+          pip uninstall -y pylance pandas
           pytest -vv python/tests/test_table.py
       # Make sure wheels are not included in the Rust cache
       - name: Delete wheels
```
Cargo.lock (generated, 119 changes)

```diff
@@ -128,9 +128,9 @@ dependencies = [
 
 [[package]]
 name = "anyhow"
-version = "1.0.97"
+version = "1.0.98"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
+checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
 
 [[package]]
 name = "arbitrary"
@@ -390,9 +390,9 @@ dependencies = [
 
 [[package]]
 name = "async-compression"
-version = "0.4.22"
+version = "0.4.23"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "59a194f9d963d8099596278594b3107448656ba73831c9d8c783e613ce86da64"
+checksum = "b37fc50485c4f3f736a4fb14199f6d5f5ba008d7f28fe710306c92780f004c07"
 dependencies = [
  "flate2",
  "futures-core",
@@ -564,9 +564,9 @@ dependencies = [
 
 [[package]]
 name = "aws-lc-sys"
-version = "0.28.0"
+version = "0.28.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b9f7720b74ed28ca77f90769a71fd8c637a0137f6fae4ae947e1050229cff57f"
+checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1"
 dependencies = [
  "bindgen",
  "cc",
@@ -882,7 +882,7 @@ dependencies = [
  "aws-smithy-async",
  "aws-smithy-runtime-api",
  "aws-smithy-types",
- "h2 0.4.8",
+ "h2 0.4.9",
  "http 0.2.12",
  "http 1.3.1",
  "http-body 0.4.6",
@@ -1185,9 +1185,9 @@ dependencies = [
 
 [[package]]
 name = "blake3"
-version = "1.8.1"
+version = "1.8.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3"
+checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0"
 dependencies = [
  "arrayref",
  "arrayvec",
@@ -2377,7 +2377,16 @@ version = "5.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
 dependencies = [
- "dirs-sys",
+ "dirs-sys 0.4.1",
 ]
 
+[[package]]
+name = "dirs"
+version = "6.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e"
+dependencies = [
+ "dirs-sys 0.5.0",
+]
+
 [[package]]
@@ -2388,10 +2397,22 @@ checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
 dependencies = [
  "libc",
  "option-ext",
- "redox_users",
+ "redox_users 0.4.6",
  "windows-sys 0.48.0",
 ]
 
+[[package]]
+name = "dirs-sys"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab"
+dependencies = [
+ "libc",
+ "option-ext",
+ "redox_users 0.5.0",
+ "windows-sys 0.59.0",
+]
+
 [[package]]
 name = "displaydoc"
 version = "0.2.5"
@@ -2558,9 +2579,9 @@ dependencies = [
 
 [[package]]
 name = "ethnum"
-version = "1.5.0"
+version = "1.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c"
+checksum = "0939f82868b77ef93ce3c3c3daf2b3c526b456741da5a1a4559e590965b6026b"
 
 [[package]]
 name = "event-listener"
@@ -3049,9 +3070,9 @@ dependencies = [
 
 [[package]]
 name = "h2"
-version = "0.4.8"
+version = "0.4.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2"
+checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633"
 dependencies = [
  "atomic-waker",
  "bytes",
@@ -3138,7 +3159,7 @@ version = "0.4.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cc03dcb0b0a83ae3f3363ec811014ae669f083e4e499c66602f447c4828737a1"
 dependencies = [
- "dirs",
+ "dirs 5.0.1",
  "futures",
  "http 1.3.1",
  "indicatif",
@@ -3286,7 +3307,7 @@ dependencies = [
  "bytes",
  "futures-channel",
  "futures-util",
- "h2 0.4.8",
+ "h2 0.4.9",
  "http 1.3.1",
  "http-body 1.0.1",
  "httparse",
@@ -3645,9 +3666,9 @@ checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8"
 
 [[package]]
 name = "jiff"
-version = "0.2.6"
+version = "0.2.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1f33145a5cbea837164362c7bd596106eb7c5198f97d1ba6f6ebb3223952e488"
+checksum = "5a064218214dc6a10fbae5ec5fa888d80c45d611aba169222fc272072bf7aef6"
 dependencies = [
  "jiff-static",
  "log",
@@ -3658,9 +3679,9 @@ dependencies = [
 
 [[package]]
 name = "jiff-static"
-version = "0.2.6"
+version = "0.2.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "43ce13c40ec6956157a3635d97a1ee2df323b263f09ea14165131289cb0f5c19"
+checksum = "199b7932d97e325aff3a7030e141eafe7f2c6268e1d1b24859b753a627f45254"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3965,7 +3986,7 @@ dependencies = [
  "datafusion-physical-expr",
  "datafusion-sql",
  "deepsize",
- "dirs",
+ "dirs 5.0.1",
  "fst",
  "futures",
  "half",
@@ -4115,7 +4136,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb"
-version = "0.19.0-beta.7"
+version = "0.19.0-beta.10"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4202,7 +4223,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb-node"
-version = "0.19.0-beta.7"
+version = "0.19.0-beta.10"
 dependencies = [
  "arrow-array",
  "arrow-ipc",
@@ -4227,7 +4248,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb-nodejs"
-version = "0.19.0-beta.7"
+version = "0.19.0-beta.10"
 dependencies = [
  "arrow-array",
  "arrow-ipc",
@@ -4245,7 +4266,7 @@ dependencies = [
 
 [[package]]
 name = "lancedb-python"
-version = "0.22.0-beta.7"
+version = "0.22.0-beta.10"
 dependencies = [
  "arrow",
  "env_logger",
@@ -4342,9 +4363,9 @@ dependencies = [
 
 [[package]]
 name = "libc"
-version = "0.2.171"
+version = "0.2.172"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6"
+checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
 
 [[package]]
 name = "libloading"
@@ -4368,9 +4389,9 @@ dependencies = [
 
 [[package]]
 name = "libm"
-version = "0.2.11"
+version = "0.2.13"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
+checksum = "c9627da5196e5d8ed0b0495e61e518847578da83483c37288316d9b2e03a7f72"
 
 [[package]]
 name = "libredox"
@@ -5637,9 +5658,9 @@ dependencies = [
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.94"
+version = "1.0.95"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84"
+checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
 dependencies = [
  "unicode-ident",
 ]
@@ -5837,13 +5858,13 @@ dependencies = [
 
 [[package]]
 name = "quinn-proto"
-version = "0.11.10"
+version = "0.11.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc"
+checksum = "bcbafbbdbb0f638fe3f35f3c56739f77a8a1d070cb25603226c83339b391472b"
 dependencies = [
  "bytes",
  "getrandom 0.3.2",
- "rand 0.9.0",
+ "rand 0.9.1",
  "ring",
  "rustc-hash 2.1.1",
  "rustls 0.23.26",
@@ -5903,13 +5924,12 @@ dependencies = [
 
 [[package]]
 name = "rand"
-version = "0.9.0"
+version = "0.9.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
+checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
 dependencies = [
  "rand_chacha 0.9.0",
  "rand_core 0.9.3",
- "zerocopy 0.8.24",
 ]
 
 [[package]]
@@ -6084,6 +6104,17 @@ dependencies = [
  "thiserror 1.0.69",
 ]
 
+[[package]]
+name = "redox_users"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b"
+dependencies = [
+ "getrandom 0.2.15",
+ "libredox",
+ "thiserror 2.0.12",
+]
+
 [[package]]
 name = "regex"
 version = "1.11.1"
@@ -6152,7 +6183,7 @@ dependencies = [
  "encoding_rs",
  "futures-core",
  "futures-util",
- "h2 0.4.8",
+ "h2 0.4.9",
  "http 1.3.1",
  "http-body 1.0.1",
  "http-body-util",
@@ -6701,11 +6732,11 @@ dependencies = [
 
 [[package]]
 name = "shellexpand"
-version = "3.1.0"
+version = "3.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "da03fa3b94cc19e3ebfc88c4229c49d8f08cdbd1228870a45f0ffdf84988e14b"
+checksum = "8b1fdf65dd6331831494dd616b30351c38e96e45921a27745cf98490458b90bb"
 dependencies = [
- "dirs",
+ "dirs 6.0.0",
 ]
 
 [[package]]
@@ -6716,9 +6747,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
 
 [[package]]
 name = "signal-hook-registry"
-version = "1.4.2"
+version = "1.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
+checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410"
 dependencies = [
  "libc",
 ]
```
````diff
@@ -765,7 +765,10 @@ This can be used to update zero to all rows depending on how many rows match the
     ];
     const tbl = await db.createTable("my_table", data)
 
-    await tbl.update({vector: [10, 10]}, { where: "x = 2"})
+    await tbl.update({
+      values: { vector: [10, 10] },
+      where: "x = 2"
+    });
     ```
 
 === "vectordb (deprecated)"
@@ -784,7 +787,10 @@ This can be used to update zero to all rows depending on how many rows match the
     ];
     const tbl = await db.createTable("my_table", data)
 
-    await tbl.update({ where: "x = 2", values: {vector: [10, 10]} })
+    await tbl.update({
+      where: "x = 2",
+      values: { vector: [10, 10] }
+    });
     ```
 
 #### Updating using a sql query
````
```diff
@@ -753,3 +753,26 @@ Retrieve the version of the table
 #### Returns
 
 `Promise`<`number`>
+
+***
+
+### waitForIndex()
+
+```ts
+abstract waitForIndex(indexNames, timeoutSeconds): Promise<void>
+```
+
+Waits for asynchronous indexing to complete on the table.
+
+#### Parameters
+
+* **indexNames**: `string`[]
+  The name of the indices to wait for
+
+* **timeoutSeconds**: `number`
+  The number of seconds to wait before timing out
+
+This will raise an error if the indices are not created and fully indexed within the timeout.
+
+#### Returns
+
+`Promise`<`void`>
```
```diff
@@ -39,3 +39,11 @@ and the same name, then an error will be returned. This is true even if
 that index is out of date.
 
 The default is true
+
+***
+
+### waitTimeoutSeconds?
+
+```ts
+optional waitTimeoutSeconds: number;
+```
```
```diff
@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.19.0-beta.7</version>
+    <version>0.19.0-beta.10</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
```
```diff
@@ -6,7 +6,7 @@
 
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.19.0-beta.7</version>
+  <version>0.19.0-beta.10</version>
   <packaging>pom</packaging>
 
   <name>LanceDB Parent</name>
```
node/package-lock.json (generated, 44 changes)

```diff
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.19.0-beta.7",
+      "version": "0.19.0-beta.10",
       "cpu": [
         "x64",
         "arm64"
@@ -52,11 +52,11 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.19.0-beta.7",
-        "@lancedb/vectordb-darwin-x64": "0.19.0-beta.7",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.7",
-        "@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.7",
-        "@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.7"
+        "@lancedb/vectordb-darwin-arm64": "0.19.0-beta.10",
+        "@lancedb/vectordb-darwin-x64": "0.19.0-beta.10",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.10",
+        "@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.10",
+        "@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.10"
       },
       "peerDependencies": {
         "@apache-arrow/ts": "^14.0.2",
@@ -327,9 +327,9 @@
       }
     },
     "node_modules/@lancedb/vectordb-darwin-arm64": {
-      "version": "0.19.0-beta.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.7.tgz",
-      "integrity": "sha512-HpbVKw4Vs+mPv7uPwaK7ilJlGrGdjOrNlC2mSkMCj0OlEwGRVcEcrSyijI7LXQH7ybEgNnDhSds5TuzBV26SGg==",
+      "version": "0.19.0-beta.10",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.10.tgz",
+      "integrity": "sha512-4PvsrE+hJ+AqFY33yHIEVET2ayv0pzpiWqCnMQ5OdpakQZvfykmp9ykc5KI80VuWAlniJDYuW+fju3z8/wiUHQ==",
       "cpu": [
         "arm64"
       ],
@@ -340,9 +340,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.19.0-beta.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.7.tgz",
-      "integrity": "sha512-x3X7nqIYVZtxaa0uZUk/M99vKvDinZ5G0+8k2NqZ696YXGWKGyRxR6k8ZzKYCoCTSuYXnBftgKoIlwJGtNt8Bw==",
+      "version": "0.19.0-beta.10",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.10.tgz",
+      "integrity": "sha512-YHXHkD4mmIry+KMoTX7Qts5Ea9fG0DklywIiF8TS7h/9XbXLG74lf+GUy2Eh/s1wKLd4LtRh2SbHpOtZoOH4lA==",
       "cpu": [
         "x64"
       ],
@@ -353,9 +353,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.19.0-beta.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.7.tgz",
-      "integrity": "sha512-Vwj0HI3+b4NgXKf+5+W/GfLBCGoQMBGM47vA/ts1dpe/PxraOQYPDv67I5kbXkCQKwhal7b0iZx/PbMu0JZPyw==",
+      "version": "0.19.0-beta.10",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.10.tgz",
+      "integrity": "sha512-c019rw30N25WIXnkhAwJ4QkpUcUJqbkGay3RiR3vTmGQ5YOZWw5V5g/v2y7APcv+ZlZfJ4YgDjFH8wqtiECNJQ==",
       "cpu": [
         "arm64"
       ],
@@ -366,9 +366,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.19.0-beta.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.7.tgz",
-      "integrity": "sha512-Dx2B6UWQei9D7Rt+MgHWqPTYtEK2w3EgsNb5ENEWUTZxH7lD/CV7Sw0JMK5LDG209fFcpXFerveF6J8ZC8uGBQ==",
+      "version": "0.19.0-beta.10",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.10.tgz",
+      "integrity": "sha512-xnbC6rqpuJDv2q6xNBKrrocNOTcM4z6+8Zi7wP+Sb+WXvLzkR7hm7ZS0gyeExRknEEd91imhL/ZuAxEq1892YA==",
       "cpu": [
         "x64"
       ],
@@ -379,9 +379,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.19.0-beta.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.7.tgz",
-      "integrity": "sha512-F5LZGa+gkUH1TgsWZWLLAMejwXFIWdash7+85ip4k2M0ThyqLF/dtlldOvteUEd5+flxihGjHg6TUtnSY8XBFA==",
+      "version": "0.19.0-beta.10",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.10.tgz",
+      "integrity": "sha512-bUDKvY4tmMEeBpzPfRu2lVo+07nRGzUoUm60WzfvRpa/Y6rwjcCCRuuTOvfTcrnbGYN/kw5yoUu8ZDsZ7mT77Q==",
       "cpu": [
         "x64"
       ],
```
```diff
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "description": " Serverless, low-latency vector database for AI applications",
   "private": false,
   "main": "dist/index.js",
@@ -89,10 +89,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-x64": "0.19.0-beta.7",
-    "@lancedb/vectordb-darwin-arm64": "0.19.0-beta.7",
-    "@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.7",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.7",
-    "@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.7"
+    "@lancedb/vectordb-darwin-x64": "0.19.0-beta.10",
+    "@lancedb/vectordb-darwin-arm64": "0.19.0-beta.10",
+    "@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.10",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.10",
+    "@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.10"
   }
 }
```
```diff
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.19.0-beta.7"
+version = "0.19.0-beta.10"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
```
```diff
@@ -507,6 +507,15 @@ describe("When creating an index", () => {
     expect(indices2.length).toBe(0);
   });
 
+  it("should wait for index readiness", async () => {
+    // Create an index and then wait for it to be ready
+    await tbl.createIndex("vec");
+    const indices = await tbl.listIndices();
+    expect(indices.length).toBeGreaterThan(0);
+    const idxName = indices[0].name;
+    await expect(tbl.waitForIndex([idxName], 5)).resolves.toBeUndefined();
+  });
+
   it("should search with distance range", async () => {
     await tbl.createIndex("vec");
 
@@ -824,6 +833,7 @@ describe("When creating an index", () => {
     // Only build index over v1
     await tbl.createIndex("vec", {
       config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }),
+      waitTimeoutSeconds: 30,
     });
 
     const rst = await tbl
```
```diff
@@ -681,4 +681,6 @@ export interface IndexOptions {
    * The default is true
    */
   replace?: boolean;
+
+  waitTimeoutSeconds?: number;
 }
```
```diff
@@ -246,6 +246,19 @@ export abstract class Table {
    */
   abstract prewarmIndex(name: string): Promise<void>;
 
+  /**
+   * Waits for asynchronous indexing to complete on the table.
+   *
+   * @param indexNames The name of the indices to wait for
+   * @param timeoutSeconds The number of seconds to wait before timing out
+   *
+   * This will raise an error if the indices are not created and fully indexed within the timeout.
+   */
+  abstract waitForIndex(
+    indexNames: string[],
+    timeoutSeconds: number,
+  ): Promise<void>;
+
   /**
    * Create a {@link Query} Builder.
    *
@@ -569,7 +582,12 @@ export class LocalTable extends Table {
     // Bit of a hack to get around the fact that TS has no package-scope.
     // biome-ignore lint/suspicious/noExplicitAny: skip
     const nativeIndex = (options?.config as any)?.inner;
-    await this.inner.createIndex(nativeIndex, column, options?.replace);
+    await this.inner.createIndex(
+      nativeIndex,
+      column,
+      options?.replace,
+      options?.waitTimeoutSeconds,
+    );
   }
 
   async dropIndex(name: string): Promise<void> {
@@ -580,6 +598,13 @@ export class LocalTable extends Table {
     await this.inner.prewarmIndex(name);
   }
 
+  async waitForIndex(
+    indexNames: string[],
+    timeoutSeconds: number,
+  ): Promise<void> {
+    await this.inner.waitForIndex(indexNames, timeoutSeconds);
+  }
+
   query(): Query {
     return new Query(this.inner);
   }
```
```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",
```

```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",
```

```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",
```

```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",
```

```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",
```

```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",
```

```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": [
     "win32"
   ],
```

```diff
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
```
nodejs/package-lock.json (generated, 4 changes)

```diff
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.19.0-beta.7",
+      "version": "0.19.0-beta.10",
       "cpu": [
         "x64",
         "arm64"
```
```diff
@@ -11,7 +11,7 @@
     "ann"
   ],
   "private": false,
-  "version": "0.19.0-beta.7",
+  "version": "0.19.0-beta.10",
   "main": "dist/index.js",
   "exports": {
     ".": "./dist/index.js",
```
```diff
@@ -111,6 +111,7 @@ impl Table {
         index: Option<&Index>,
         column: String,
         replace: Option<bool>,
+        wait_timeout_s: Option<i64>,
     ) -> napi::Result<()> {
         let lancedb_index = if let Some(index) = index {
             index.consume()?
@@ -121,6 +122,10 @@ impl Table {
         if let Some(replace) = replace {
             builder = builder.replace(replace);
         }
+        if let Some(timeout) = wait_timeout_s {
+            builder =
+                builder.wait_timeout(std::time::Duration::from_secs(timeout.try_into().unwrap()));
+        }
         builder.execute().await.default_error()
     }
 
@@ -140,6 +145,18 @@ impl Table {
             .default_error()
     }
 
+    #[napi(catch_unwind)]
+    pub async fn wait_for_index(&self, index_names: Vec<String>, timeout_s: i64) -> Result<()> {
+        let timeout = std::time::Duration::from_secs(timeout_s.try_into().unwrap());
+        let index_names: Vec<&str> = index_names.iter().map(|s| s.as_str()).collect();
+        let slice: &[&str] = &index_names;
+
+        self.inner_ref()?
+            .wait_for_index(slice, timeout)
+            .await
+            .default_error()
+    }
+
     #[napi(catch_unwind)]
     pub async fn update(
         &self,
```
```diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.22.0-beta.8"
+current_version = "0.22.0-beta.11"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
```
```diff
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.22.0-beta.8"
+version = "0.22.0-beta.11"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
```
```diff
@@ -77,6 +77,7 @@ embeddings = [
     "pillow",
     "open-clip-torch",
     "cohere",
+    "colpali-engine>=0.3.10",
     "huggingface_hub",
     "InstructorEmbedding",
     "google.generativeai",
```
```diff
@@ -9,7 +9,7 @@ import numpy as np
 import pyarrow as pa
 import pyarrow.dataset
 
-from .dependencies import pandas as pd
+from .dependencies import _check_for_pandas, pandas as pd
 
 DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
 VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
@@ -63,7 +63,7 @@ def data_to_reader(
     data: DATA, schema: Optional[pa.Schema] = None
 ) -> pa.RecordBatchReader:
     """Convert various types of input into a RecordBatchReader"""
-    if pd is not None and isinstance(data, pd.DataFrame):
+    if _check_for_pandas(data) and isinstance(data, pd.DataFrame):
        return pa.Table.from_pandas(data, schema=schema).to_reader()
    elif isinstance(data, pa.Table):
        return data.to_reader()
```
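The point of this change is that `data_to_reader` no longer requires pandas just to test whether `data` is a DataFrame. A minimal sketch of the guard pattern, inferred from the call site above (this is not LanceDB's exact `_check_for_pandas` implementation):

```python
import sys
from typing import Any


def _check_for_pandas(obj: Any) -> bool:
    # If pandas was never imported, obj cannot be a pandas.DataFrame,
    # so we can answer False without importing (and paying for) pandas.
    pd = sys.modules.get("pandas")
    return pd is not None and isinstance(obj, pd.DataFrame)
```

Together with the CI change above that reruns the table tests after `pip uninstall -y pylance pandas`, this keeps pandas a genuinely optional dependency.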
```diff
@@ -19,3 +19,4 @@ from .imagebind import ImageBindEmbeddings
 from .jinaai import JinaEmbeddings
 from .watsonx import WatsonxEmbeddings
 from .voyageai import VoyageAIEmbeddingFunction
+from .colpali import ColPaliEmbeddings
```
python/python/lancedb/embeddings/colpali.py (new file, 255 lines)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors


from functools import lru_cache
from typing import List, Union, Optional, Any
import numpy as np
import io

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import TEXT, IMAGES, is_flash_attn_2_available


@register("colpali")
class ColPaliEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the ColPali engine for
    multimodal multi-vector embeddings.

    This embedding function supports ColQwen2.5 models, producing multivector outputs
    for both text and image inputs. The output embeddings are lists of vectors, each
    vector being 128-dimensional by default, represented as List[List[float]].

    Parameters
    ----------
    model_name : str
        The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
    device : str
        The device for inference (default "cuda:0").
    dtype : str
        Data type for model weights (default "bfloat16").
    use_token_pooling : bool
        Whether to use token pooling to reduce embedding size (default True).
    pool_factor : int
        Factor to reduce sequence length if token pooling is enabled (default 2).
    quantization_config : Optional[BitsAndBytesConfig]
        Quantization configuration for the model. (default None, bitsandbytes needed)
    batch_size : int
        Batch size for processing inputs (default 2).
    """

    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
    device: str = "auto"
    dtype: str = "bfloat16"
    use_token_pooling: bool = True
    pool_factor: int = 2
    quantization_config: Optional[Any] = None
    batch_size: int = 2

    _model = None
    _processor = None
    _token_pooler = None
    _vector_dim = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        (
            self._model,
            self._processor,
            self._token_pooler,
        ) = self._load_model(
            self.model_name,
            self.dtype,
            self.device,
            self.use_token_pooling,
            self.quantization_config,
        )

    @staticmethod
    @lru_cache(maxsize=1)
    def _load_model(
        model_name: str,
        dtype: str,
        device: str,
        use_token_pooling: bool,
        quantization_config: Optional[Any],
    ):
        """
        Initialize and cache the ColPali model, processor, and token pooler.
        """
        torch = attempt_import_or_raise("torch", "torch")
        transformers = attempt_import_or_raise("transformers", "transformers")
        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
        from colpali_engine.compression.token_pooling import HierarchicalTokenPooler

        if quantization_config is not None:
            if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
                raise ValueError("quantization_config must be a BitsAndBytesConfig")

        if dtype == "bfloat16":
            torch_dtype = torch.bfloat16
        elif dtype == "float16":
            torch_dtype = torch.float16
        elif dtype == "float64":
            torch_dtype = torch.float64
        else:
            torch_dtype = torch.float32

        model = colpali_engine.models.ColQwen2_5.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device,
            quantization_config=quantization_config
            if quantization_config is not None
            else None,
            attn_implementation="flash_attention_2"
            if is_flash_attn_2_available()
            else None,
        ).eval()
        processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
            model_name
        )
        token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
        return model, processor, token_pooler

    def ndims(self):
        """
        Return the dimension of a vector in the multivector output (e.g., 128).
        """
        torch = attempt_import_or_raise("torch", "torch")
        if self._vector_dim is None:
            dummy_query = "test"
            batch_queries = self._processor.process_queries([dummy_query]).to(
                self._model.device
            )
            with torch.no_grad():
                query_embeddings = self._model(**batch_queries)

            if self.use_token_pooling and self._token_pooler is not None:
                query_embeddings = self._token_pooler.pool_embeddings(
                    query_embeddings,
                    pool_factor=self.pool_factor,
                    padding=True,
                    padding_side=self._processor.tokenizer.padding_side,
                )

            self._vector_dim = query_embeddings[0].shape[-1]
        return self._vector_dim

    def _process_embeddings(self, embeddings):
        """
        Format model embeddings into List[List[float]].
        Use token pooling if enabled.
        """
        torch = attempt_import_or_raise("torch", "torch")
        if self.use_token_pooling and self._token_pooler is not None:
            embeddings = self._token_pooler.pool_embeddings(
                embeddings,
                pool_factor=self.pool_factor,
                padding=True,
                padding_side=self._processor.tokenizer.padding_side,
            )

        if isinstance(embeddings, torch.Tensor):
            tensors = embeddings.detach().cpu()
            if tensors.dtype == torch.bfloat16:
                tensors = tensors.to(torch.float32)
            return (
                tensors.numpy()
                .astype(np.float64 if self.dtype == "float64" else np.float32)
                .tolist()
            )
        return []

    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
        """
        Generate embeddings for text input.
        """
        torch = attempt_import_or_raise("torch", "torch")
        text = self.sanitize_input(text)
        all_embeddings = []

        for i in range(0, len(text), self.batch_size):
            batch_text = text[i : i + self.batch_size]
            batch_queries = self._processor.process_queries(batch_text).to(
                self._model.device
            )
            with torch.no_grad():
                query_embeddings = self._model(**batch_queries)
            all_embeddings.extend(self._process_embeddings(query_embeddings))
        return all_embeddings

    def _prepare_images(self, images: IMAGES) -> List:
        """
        Convert image inputs to PIL Images.
        """
        PIL = attempt_import_or_raise("PIL", "pillow")
        requests = attempt_import_or_raise("requests", "requests")
        images = self.sanitize_input(images)
        pil_images = []
        try:
            for image in images:
                if isinstance(image, str):
                    if image.startswith(("http://", "https://")):
                        response = requests.get(image, timeout=10)
                        response.raise_for_status()
                        pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
                    else:
                        with PIL.Image.open(image) as im:
                            pil_images.append(im.copy())
                elif isinstance(image, bytes):
                    pil_images.append(PIL.Image.open(io.BytesIO(image)))
                else:
                    # Assume it's a PIL Image; will raise if invalid
                    pil_images.append(image)
        except Exception as e:
            raise ValueError(f"Failed to process image: {e}")

        return pil_images

    def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
        """
        Generate embeddings for a batch of images.
        """
        torch = attempt_import_or_raise("torch", "torch")
        pil_images = self._prepare_images(images)
        all_embeddings = []

        for i in range(0, len(pil_images), self.batch_size):
            batch_images = pil_images[i : i + self.batch_size]
            batch_images = self._processor.process_images(batch_images).to(
                self._model.device
            )
            with torch.no_grad():
                image_embeddings = self._model(**batch_images)
            all_embeddings.extend(self._process_embeddings(image_embeddings))
        return all_embeddings

    def compute_query_embeddings(
        self, query: Union[str, IMAGES], *args, **kwargs
    ) -> List[List[List[float]]]:
        """
        Compute embeddings for a single user query (text only).
        """
        if not isinstance(query, str):
            raise ValueError(
                "Query must be a string, image to image search is not supported"
            )
        return self.generate_text_embeddings([query])

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[List[List[float]]]:
        """
        Compute embeddings for a batch of source images.

        Parameters
        ----------
        images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
            Batch of images (paths, URLs, bytes, or PIL Images).
        """
        images = self.sanitize_input(images)
        return self.generate_image_embeddings(images)
```
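A short usage sketch of the new embedding function, distilled from the integration test further down; it assumes `colpali-engine` and `torch` are installed and that the model weights can be downloaded. The database path and image URL are illustrative:

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, MultiVector

db = lancedb.connect("/tmp/colpali-demo")  # illustrative path
func = get_registry().get("colpali").create()


class MediaItems(LanceModel):
    text: str
    image_uri: str = func.SourceField()
    # One multivector (a list of func.ndims()-dimensional vectors) per image
    image_vectors: MultiVector(func.ndims()) = func.VectorField()


table = db.create_table("media", schema=MediaItems)
table.add([{"text": "a cute cat", "image_uri": "https://example.com/cat.jpg"}])  # illustrative URL

# Text queries and image sources land in the same multivector space.
results = (
    table.search("fluffy companion", vector_column_name="image_vectors")
    .limit(1)
    .to_pydantic(MediaItems)
)
```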
```diff
@@ -18,6 +18,7 @@ import numpy as np
 import pyarrow as pa
 
 from ..dependencies import pandas as pd
+from ..util import attempt_import_or_raise
 
 
 # ruff: noqa: PERF203
@@ -275,3 +276,12 @@ def url_retrieve(url: str):
 def api_key_not_found_help(provider):
     logging.error("Could not find API key for %s", provider)
     raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
+
+
+def is_flash_attn_2_available():
+    try:
+        attempt_import_or_raise("flash_attn", "flash_attn")
+
+        return True
+    except ImportError:
+        return False
```
```diff
@@ -152,6 +152,104 @@ def Vector(
     return FixedSizeList
 
 
+def MultiVector(
+    dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True
+) -> Type:
+    """Pydantic MultiVector Type for multi-vector embeddings.
+
+    This type represents a list of vectors, each with the same dimension.
+    Useful for models that produce multiple embeddings per input, like ColPali.
+
+    Parameters
+    ----------
+    dim : int
+        The dimension of each vector in the multi-vector.
+    value_type : pyarrow.DataType, optional
+        The value type of the vectors, by default pa.float32()
+    nullable : bool, optional
+        Whether the multi-vector is nullable, by default it is True.
+
+    Examples
+    --------
+
+    >>> import pydantic
+    >>> from lancedb.pydantic import MultiVector
+    ...
+    >>> class MyModel(pydantic.BaseModel):
+    ...     id: int
+    ...     text: str
+    ...     embeddings: MultiVector(128)  # List of 128-dimensional vectors
+    >>> schema = pydantic_to_schema(MyModel)
+    >>> assert schema == pa.schema([
+    ...     pa.field("id", pa.int64(), False),
+    ...     pa.field("text", pa.utf8(), False),
+    ...     pa.field("embeddings", pa.list_(pa.list_(pa.float32(), 128)))
+    ... ])
+    """
+
+    class MultiVectorList(list, FixedSizeListMixin):
+        def __repr__(self):
+            return f"MultiVector(dim={dim})"
+
+        @staticmethod
+        def nullable() -> bool:
+            return nullable
+
+        @staticmethod
+        def dim() -> int:
+            return dim
+
+        @staticmethod
+        def value_arrow_type() -> pa.DataType:
+            return value_type
+
+        @staticmethod
+        def is_multi_vector() -> bool:
+            return True
+
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
+        ) -> CoreSchema:
+            return core_schema.no_info_after_validator_function(
+                cls,
+                core_schema.list_schema(
+                    items_schema=core_schema.list_schema(
+                        min_length=dim,
+                        max_length=dim,
+                        items_schema=core_schema.float_schema(),
+                    ),
+                ),
+            )
+
+        @classmethod
+        def __get_validators__(cls) -> Generator[Callable, None, None]:
+            yield cls.validate
+
+        # For pydantic v1
+        @classmethod
+        def validate(cls, v):
+            if not isinstance(v, (list, range)):
+                raise TypeError("A list of vectors is needed")
+            for vec in v:
+                if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim:
+                    raise TypeError(f"Each vector must be a list of {dim} numbers")
+            return cls(v)
+
+        if PYDANTIC_VERSION.major < 2:
+
+            @classmethod
+            def __modify_schema__(cls, field_schema: Dict[str, Any]):
+                field_schema["items"] = {
+                    "type": "array",
+                    "items": {"type": "number"},
+                    "minItems": dim,
+                    "maxItems": dim,
+                }
+
+    return MultiVectorList
+
+
 def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
     """Convert a field with native Python type to Arrow data type.
 
@@ -206,6 +304,9 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
         fields = _pydantic_model_to_fields(tp)
         return pa.struct(fields)
     if issubclass(tp, FixedSizeListMixin):
+        if getattr(tp, "is_multi_vector", lambda: False)():
+            return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
+        # For regular Vector
         return pa.list_(tp.value_arrow_type(), tp.dim())
     return _py_type_to_arrow_type(tp, field)
```
```diff
@@ -28,6 +28,8 @@ import pyarrow.compute as pc
 import pyarrow.fs as pa_fs
 import pydantic
 
+from lancedb.pydantic import PYDANTIC_VERSION
+
 from . import __version__
 from .arrow import AsyncRecordBatchReader
 from .dependencies import pandas as pd
@@ -498,10 +500,14 @@ class Query(pydantic.BaseModel):
         )
         return query
 
-    class Config:
-        # This tells pydantic to allow custom types (needed for the `vector` query since
-        # pa.Array wouln't be allowed otherwise)
-        arbitrary_types_allowed = True
+    # This tells pydantic to allow custom types (needed for the `vector` query since
+    # pa.Array wouln't be allowed otherwise)
+    if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat
+
+        class Config:
+            arbitrary_types_allowed = True
+    else:
+        model_config = {"arbitrary_types_allowed": True}
 
 
 class LanceQueryBuilder(ABC):
@@ -1586,6 +1592,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         self._refine_factor = None
         self._distance_type = None
         self._phrase_query = None
+        self._lower_bound = None
+        self._upper_bound = None
 
     def _validate_query(self, query, vector=None, text=None):
         if query is not None and (vector is not None or text is not None):
@@ -1665,6 +1673,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
             self._vector_query.ef(self._ef)
         if self._bypass_vector_index:
             self._vector_query.bypass_vector_index()
+        if self._lower_bound or self._upper_bound:
+            self._vector_query.distance_range(
+                lower_bound=self._lower_bound, upper_bound=self._upper_bound
+            )
 
         if self._reranker is None:
             self._reranker = RRFReranker()
```
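The `Config` change above is the usual way to keep one pydantic model working across 1.x and 2.x: version-gate the class body itself. A standalone sketch of the pattern (LanceDB computes `PYDANTIC_VERSION` in `lancedb.pydantic`; the `packaging` import here is an assumption):

```python
import pydantic
from packaging.version import Version

PYDANTIC_VERSION = Version(pydantic.VERSION)


class Query(pydantic.BaseModel):
    # Allow non-pydantic field types such as pyarrow arrays.
    if PYDANTIC_VERSION.major < 2:  # pydantic 1.x reads a nested Config class

        class Config:
            arbitrary_types_allowed = True

    else:  # pydantic 2.x reads model_config instead
        model_config = {"arbitrary_types_allowed": True}
```

Because a class body is ordinary code executed at definition time, the `if` picks the right configuration mechanism once, at import.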
```diff
@@ -104,6 +104,7 @@ class RemoteTable(Table):
         index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
         *,
         replace: bool = False,
+        wait_timeout: timedelta = None,
     ):
         """Creates a scalar index
         Parameters
@@ -126,13 +127,18 @@ class RemoteTable(Table):
         else:
             raise ValueError(f"Unknown index type: {index_type}")
 
-        LOOP.run(self._table.create_index(column, config=config, replace=replace))
+        LOOP.run(
+            self._table.create_index(
+                column, config=config, replace=replace, wait_timeout=wait_timeout
+            )
+        )
 
     def create_fts_index(
         self,
         column: str,
         *,
         replace: bool = False,
+        wait_timeout: timedelta = None,
         with_position: bool = True,
         # tokenizer configs:
         base_tokenizer: str = "simple",
@@ -153,7 +159,11 @@ class RemoteTable(Table):
             remove_stop_words=remove_stop_words,
             ascii_folding=ascii_folding,
         )
-        LOOP.run(self._table.create_index(column, config=config, replace=replace))
+        LOOP.run(
+            self._table.create_index(
+                column, config=config, replace=replace, wait_timeout=wait_timeout
+            )
+        )
 
     def create_index(
         self,
@@ -165,6 +175,7 @@ class RemoteTable(Table):
         replace: Optional[bool] = None,
         accelerator: Optional[str] = None,
         index_type="vector",
+        wait_timeout: Optional[timedelta] = None,
     ):
         """Create an index on the table.
         Currently, the only parameters that matter are
@@ -236,7 +247,11 @@ class RemoteTable(Table):
             " 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
         )
 
-        LOOP.run(self._table.create_index(vector_column_name, config=config))
+        LOOP.run(
+            self._table.create_index(
+                vector_column_name, config=config, wait_timeout=wait_timeout
+            )
+        )
 
     def add(
         self,
@@ -554,6 +569,11 @@ class RemoteTable(Table):
     def drop_index(self, index_name: str):
         return LOOP.run(self._table.drop_index(index_name))
 
+    def wait_for_index(
+        self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
+    ):
+        return LOOP.run(self._table.wait_for_index(index_names, timeout))
+
     def uses_v2_manifest_paths(self) -> bool:
         raise NotImplementedError(
             "uses_v2_manifest_paths() is not supported on the LanceDB Cloud"
```
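A usage sketch of the new `wait_timeout` plumbing on a remote table; the connection URI, API key, and table/column names below are placeholders, not values from this change set:

```python
from datetime import timedelta

import lancedb

# Placeholders: substitute your own LanceDB Cloud database and API key.
db = lancedb.connect("db://my-database", api_key="sk-...")
table = db.open_table("my_table")

# Remote indexing is asynchronous; wait_timeout makes the call block
# until the index is fully built, or raise if the timeout is reached.
table.create_scalar_index("id", index_type="BTREE", wait_timeout=timedelta(minutes=2))
```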
```diff
@@ -631,6 +631,7 @@ class Table(ABC):
         index_cache_size: Optional[int] = None,
         *,
         index_type: VectorIndexType = "IVF_PQ",
+        wait_timeout: Optional[timedelta] = None,
         num_bits: int = 8,
         max_iterations: int = 50,
         sample_rate: int = 256,
@@ -666,6 +667,8 @@ class Table(ABC):
         num_bits: int
             The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
             Only 4 and 8 are supported.
+        wait_timeout: timedelta, optional
+            The timeout to wait if indexing is asynchronous.
         """
         raise NotImplementedError
 
@@ -689,6 +692,23 @@ class Table(ABC):
         """
         raise NotImplementedError
 
+    def wait_for_index(
+        self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
+    ) -> None:
+        """
+        Wait for indexing to complete for the given index names.
+        This will poll the table until all the indices are fully indexed,
+        or raise a timeout exception if the timeout is reached.
+
+        Parameters
+        ----------
+        index_names: str
+            The name of the indices to poll
+        timeout: timedelta
+            Timeout to wait for asynchronous indexing. The default is 5 minutes.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def create_scalar_index(
         self,
@@ -696,6 +716,7 @@ class Table(ABC):
         *,
         replace: bool = True,
         index_type: ScalarIndexType = "BTREE",
+        wait_timeout: Optional[timedelta] = None,
     ):
         """Create a scalar index on a column.
 
@@ -708,7 +729,8 @@ class Table(ABC):
             Replace the existing index if it exists.
         index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE"
             The type of index to create.
-
+        wait_timeout: timedelta, optional
+            The timeout to wait if indexing is asynchronous.
         Examples
         --------
@@ -767,6 +789,7 @@ class Table(ABC):
         stem: bool = False,
         remove_stop_words: bool = False,
         ascii_folding: bool = False,
+        wait_timeout: Optional[timedelta] = None,
     ):
         """Create a full-text search index on the table.
 
@@ -822,6 +845,8 @@ class Table(ABC):
         ascii_folding : bool, default False
             Whether to fold ASCII characters. This converts accented characters to
             their ASCII equivalent. For example, "café" would be converted to "cafe".
+        wait_timeout: timedelta, optional
+            The timeout to wait if indexing is asynchronous.
         """
         raise NotImplementedError
 
@@ -1771,6 +1796,11 @@ class LanceTable(Table):
         """
         return LOOP.run(self._table.prewarm_index(name))
 
+    def wait_for_index(
+        self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
+    ) -> None:
+        return LOOP.run(self._table.wait_for_index(index_names, timeout))
+
     def create_scalar_index(
         self,
         column: str,
@@ -2964,6 +2994,7 @@ class AsyncTable:
         config: Optional[
             Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
         ] = None,
+        wait_timeout: Optional[timedelta] = None,
     ):
         """Create an index to speed up queries
 
@@ -2988,6 +3019,8 @@ class AsyncTable:
         For advanced configuration you can specify the type of index you would
         like to create. You can also specify index-specific parameters when
         creating an index object.
+        wait_timeout: timedelta, optional
+            The timeout to wait if indexing is asynchronous.
         """
         if config is not None:
             if not isinstance(
@@ -2998,7 +3031,9 @@ class AsyncTable:
                 " Bitmap, LabelList, or FTS"
             )
         try:
-            await self._inner.create_index(column, index=config, replace=replace)
+            await self._inner.create_index(
+                column, index=config, replace=replace, wait_timeout=wait_timeout
+            )
         except ValueError as e:
             if "not support the requested language" in str(e):
                 supported_langs = ", ".join(lang_mapping.values())
@@ -3043,6 +3078,23 @@ class AsyncTable:
         """
         await self._inner.prewarm_index(name)
 
+    async def wait_for_index(
+        self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
+    ) -> None:
+        """
+        Wait for indexing to complete for the given index names.
+        This will poll the table until all the indices are fully indexed,
+        or raise a timeout exception if the timeout is reached.
+
+        Parameters
+        ----------
+        index_names: str
+            The name of the indices to poll
+        timeout: timedelta
+            Timeout to wait for asynchronous indexing. The default is 5 minutes.
+        """
+        await self._inner.wait_for_index(index_names, timeout)
+
     async def add(
         self,
         data: DATA,
```
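And the new waiting API from the local Python side, mirroring the docstrings above. A minimal sketch; the database path, table name, and index parameters are illustrative:

```python
from datetime import timedelta

import lancedb

db = lancedb.connect("/tmp/wait-demo")  # illustrative path
table = db.open_table("my_table")
table.create_index(vector_column_name="vector", metric="cosine")

# Block until every index on the table is fully built, or raise on timeout.
names = [idx.name for idx in table.list_indices()]
table.wait_for_index(names, timeout=timedelta(seconds=60))
```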
```diff
@@ -11,7 +11,7 @@ import pandas as pd
 import pyarrow as pa
 import pytest
 from lancedb.embeddings import get_registry
-from lancedb.pydantic import LanceModel, Vector
+from lancedb.pydantic import LanceModel, Vector, MultiVector
 import requests
 
 # These are integration tests for embedding functions.
@@ -575,3 +575,67 @@ def test_voyageai_multimodal_embedding_text_function():
 
     tbl.add(df)
     assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    importlib.util.find_spec("colpali_engine") is None,
+    reason="colpali_engine not installed",
+)
+def test_colpali(tmp_path):
+    import requests
+    from lancedb.pydantic import LanceModel
+
+    db = lancedb.connect(tmp_path)
+    registry = get_registry()
+    func = registry.get("colpali").create()
+
+    class MediaItems(LanceModel):
+        text: str
+        image_uri: str = func.SourceField()
+        image_bytes: bytes = func.SourceField()
+        image_vectors: MultiVector(func.ndims()) = (
+            func.VectorField()
+        )  # Multivector image embeddings
+
+    table = db.create_table("media", schema=MediaItems)
+
+    texts = [
+        "a cute cat playing with yarn",
+        "a puppy in a flower field",
+        "a red sports car on the highway",
+        "a vintage bicycle leaning against a wall",
+        "a plate of delicious pasta",
+        "fresh fruit salad in a bowl",
+    ]
+
+    uris = [
+        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
+        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
+        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
+    ]
+
+    # Get images as bytes
+    image_bytes = [requests.get(uri).content for uri in uris]
+
+    table.add(
+        pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
+    )
+
+    # Test text-to-image search
+    image_results = (
+        table.search("fluffy companion", vector_column_name="image_vectors")
+        .limit(1)
+        .to_pydantic(MediaItems)[0]
+    )
+    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
+
+    # Verify multivector dimensions
+    first_row = table.to_arrow().to_pylist()[0]
+    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
+    assert len(first_row["image_vectors"][0]) == func.ndims(), (
+        "Vector dimension mismatch"
+    )
```
```diff
@@ -4,13 +4,32 @@
 import lancedb
 
 from lancedb.query import LanceHybridQueryBuilder
+from lancedb.rerankers.rrf import RRFReranker
 import pyarrow as pa
+import pyarrow.compute as pc
 import pytest
 import pytest_asyncio
 
 from lancedb.index import FTS
-from lancedb.table import AsyncTable
+from lancedb.table import AsyncTable, Table
+
+
+@pytest.fixture
+def sync_table(tmpdir_factory) -> Table:
+    tmp_path = str(tmpdir_factory.mktemp("data"))
+    db = lancedb.connect(tmp_path)
+    data = pa.table(
+        {
+            "text": pa.array(["a", "b", "cat", "dog"]),
+            "vector": pa.array(
+                [[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
+                type=pa.list_(pa.float32(), list_size=2),
+            ),
+        }
+    )
+    table = db.create_table("test", data)
+    table.create_fts_index("text", with_position=False, use_tantivy=False)
+    return table
 
 
 @pytest_asyncio.fixture
@@ -102,6 +121,42 @@ async def test_async_hybrid_query_default_limit(table: AsyncTable):
     assert texts.count("a") == 1
 
 
+def test_hybrid_query_distance_range(sync_table: Table):
+    reranker = RRFReranker(return_score="all")
+    result = (
+        sync_table.search(query_type="hybrid")
+        .vector([0.0, 0.4])
+        .text("cat and dog")
+        .distance_range(lower_bound=0.2, upper_bound=0.5)
+        .rerank(reranker)
+        .limit(2)
+        .to_arrow()
+    )
+    assert len(result) == 2
+    print(result)
+    for dist in result["_distance"]:
+        if dist.is_valid:
+            assert 0.2 <= dist.as_py() <= 0.5
+
+
+@pytest.mark.asyncio
+async def test_hybrid_query_distance_range_async(table: AsyncTable):
+    reranker = RRFReranker(return_score="all")
+    result = await (
+        table.query()
+        .nearest_to([0.0, 0.4])
+        .nearest_to_text("cat and dog")
+        .distance_range(lower_bound=0.2, upper_bound=0.5)
+        .rerank(reranker)
+        .limit(2)
+        .to_arrow()
+    )
+    assert len(result) == 2
+    for dist in result["_distance"]:
+        if dist.is_valid:
+            assert 0.2 <= dist.as_py() <= 0.5
+
+
 @pytest.mark.asyncio
 async def test_explain_plan(table: AsyncTable):
     plan = await (
```
@@ -9,7 +9,13 @@ from typing import List, Optional, Tuple

import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from lancedb.pydantic import (
    PYDANTIC_VERSION,
    LanceModel,
    Vector,
    pydantic_to_schema,
    MultiVector,
)
from pydantic import BaseModel
from pydantic import Field

@@ -354,3 +360,55 @@ def test_optional_nested_model():

            ),
        ]
    )


def test_multi_vector():
    class TestModel(pydantic.BaseModel):
        vec: MultiVector(8)

    schema = pydantic_to_schema(TestModel)
    assert schema == pa.schema(
        [pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
    )

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=[[1.0] * 7])

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=[[1.0] * 9])

    TestModel(vec=[[1.0] * 8])
    TestModel(vec=[[1.0] * 8, [2.0] * 8])

    TestModel(vec=[])


def test_multi_vector_nullable():
    class NullableModel(pydantic.BaseModel):
        vec: MultiVector(16, nullable=False)

    schema = pydantic_to_schema(NullableModel)
    assert schema == pa.schema(
        [pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)]
    )

    class DefaultModel(pydantic.BaseModel):
        vec: MultiVector(16)

    schema = pydantic_to_schema(DefaultModel)
    assert schema == pa.schema(
        [pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)]
    )


def test_multi_vector_in_lance_model():
    class TestModel(LanceModel):
        id: int
        vectors: MultiVector(16) = Field(default=[[0.0] * 16])

    schema = pydantic_to_schema(TestModel)
    assert schema == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["id", "vectors"]

    t = TestModel(id=1)
    assert t.vectors == [[0.0] * 16]
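As the tests above show, a MultiVector(dim) field is stored as an Arrow list of fixed-size float32 lists, so each row can carry a variable number of equally sized sub-vectors. A minimal sketch of declaring and populating such a schema (the table and field names are illustrative):

    import lancedb
    from lancedb.pydantic import LanceModel, MultiVector

    class Doc(LanceModel):
        id: int
        vectors: MultiVector(8)  # Arrow type: list<fixed_size_list<float32, 8>>

    db = lancedb.connect("/tmp/multivector-demo")
    table = db.create_table("docs", schema=Doc)
    # Each row carries a variable number of 8-dim sub-vectors
    table.add([Doc(id=1, vectors=[[0.1] * 8, [0.2] * 8])])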
@@ -257,7 +257,9 @@ async def test_distance_range_with_new_rows_async():
        }
    )
    table = await conn.create_table("test", data)
    table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))
    await table.create_index(
        "vector", config=IvfPq(num_partitions=1, num_sub_vectors=2)
    )

    q = [0, 0]
    rs = await table.query().nearest_to(q).to_arrow()
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import re
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import timedelta
@@ -235,6 +235,10 @@ def test_table_add_in_threadpool():

def test_table_create_indices():
    def handler(request):
        index_stats = dict(
            index_type="IVF_PQ", num_indexed_rows=1000, num_unindexed_rows=0
        )

        if request.path == "/v1/table/test/create_index/":
            request.send_response(200)
            request.end_headers()
@@ -258,6 +262,47 @@ def test_table_create_indices():
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/list/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    indexes=[
                        {
                            "index_name": "id_idx",
                            "columns": ["id"],
                        },
                        {
                            "index_name": "text_idx",
                            "columns": ["text"],
                        },
                        {
                            "index_name": "vector_idx",
                            "columns": ["vector"],
                        },
                    ]
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/id_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/text_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/vector_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            request.wfile.write(payload.encode())
        elif "/drop/" in request.path:
            request.send_response(200)
            request.end_headers()
@@ -269,14 +314,81 @@ def test_table_create_indices():
        # Parameters are well-tested through local and async tests.
        # This is a smoke-test.
        table = db.create_table("test", [{"id": 1}])
        table.create_scalar_index("id")
        table.create_fts_index("text")
        table.create_scalar_index("vector")
        table.create_scalar_index("id", wait_timeout=timedelta(seconds=2))
        table.create_fts_index("text", wait_timeout=timedelta(seconds=2))
        table.create_index(
            vector_column_name="vector", wait_timeout=timedelta(seconds=10)
        )
        table.wait_for_index(["id_idx"], timedelta(seconds=2))
        table.wait_for_index(["text_idx", "vector_idx"], timedelta(seconds=2))
        table.drop_index("vector_idx")
        table.drop_index("id_idx")
        table.drop_index("text_idx")


def test_table_wait_for_index_timeout():
    def handler(request):
        index_stats = dict(
            index_type="BTREE", num_indexed_rows=1000, num_unindexed_rows=1
        )

        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    version=1,
                    schema=dict(
                        fields=[
                            dict(name="id", type={"type": "int64"}, nullable=False),
                        ]
                    ),
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/list/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    indexes=[
                        {
                            "index_name": "id_idx",
                            "columns": ["id"],
                        },
                    ]
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/id_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            print(f"{index_stats=}")
            request.wfile.write(payload.encode())
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        with pytest.raises(
            RuntimeError,
            match=re.escape(
                'Timeout error: timed out waiting for indices: ["id_idx"] after 1s'
            ),
        ):
            table.wait_for_index(["id_idx"], timedelta(seconds=1))


@contextlib.contextmanager
def query_test_table(query_handler, *, server_version=Version("0.1.0")):
    def handler(request):
@@ -9,9 +9,9 @@ from typing import List
from unittest.mock import patch

import lancedb
from lancedb.dependencies import _PANDAS_AVAILABLE
from lancedb.index import HnswPq, HnswSq, IvfPq
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset
@@ -138,13 +138,16 @@ def test_create_table(mem_db: DBConnection):
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    df = pd.DataFrame(rows)
    pa_table = pa.Table.from_pandas(df, schema=schema)
    pa_table = pa.Table.from_pylist(rows, schema=schema)
    data = [
        ("Rows", rows),
        ("pd_DataFrame", df),
        ("pa_Table", pa_table),
    ]
    if _PANDAS_AVAILABLE:
        import pandas as pd

        df = pd.DataFrame(rows)
        data.append(("pd_DataFrame", df))

    for name, d in data:
        tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow()
@@ -296,7 +299,7 @@ def test_add_subschema(mem_db: DBConnection):

    data = {"price": 10.0, "item": "foo"}
    table.add([data])
    data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
    data = pa.Table.from_pydict({"price": [2.0], "vector": [[3.1, 4.1]]})
    table.add(data)
    data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
    table.add([data])
@@ -405,6 +408,7 @@ def test_add_nullability(mem_db: DBConnection):


def test_add_pydantic_model(mem_db: DBConnection):
    pytest.importorskip("pandas")
    # https://github.com/lancedb/lancedb/issues/562

    class Metadata(BaseModel):
@@ -473,10 +477,10 @@ def test_polars(mem_db: DBConnection):
    table = mem_db.create_table("test", data=pl.DataFrame(data))
    assert len(table) == 2

    result = table.to_pandas()
    assert np.allclose(result["vector"].tolist(), data["vector"])
    assert result["item"].tolist() == data["item"]
    assert np.allclose(result["price"].tolist(), data["price"])
    result = table.to_arrow()
    assert np.allclose(result["vector"].to_pylist(), data["vector"])
    assert result["item"].to_pylist() == data["item"]
    assert np.allclose(result["price"].to_pylist(), data["price"])

    schema = pa.schema(
        [
@@ -688,7 +692,7 @@ def test_delete(mem_db: DBConnection):
    assert len(table.list_versions()) == 2
    assert table.version == 2
    assert len(table) == 1
    assert table.to_pandas()["id"].tolist() == [1]
    assert table.to_arrow()["id"].to_pylist() == [1]


def test_update(mem_db: DBConnection):
@@ -852,6 +856,7 @@ def test_merge_insert(mem_db: DBConnection):
    ids=["pa.Table", "pd.DataFrame", "rows"],
)
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
    pytest.importorskip("pandas")
    initial_data = pa.table(
        {"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
    )
@@ -948,7 +953,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):

    func = MockTextEmbeddingFunction.create()
    texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
    df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
    df = pa.table({"text": texts, "vector": func.compute_source_embeddings(texts)})

    conf = EmbeddingFunctionConfig(
        source_column="text", vector_column="vector", function=func
@@ -973,7 +978,7 @@ def test_create_f16_table(mem_db: DBConnection):
        text: str
        vector: Vector(32, value_type=pa.float16())

    df = pd.DataFrame(
    df = pa.table(
        {
            "text": [f"s-{i}" for i in range(512)],
            "vector": [np.random.randn(32).astype(np.float16) for _ in range(512)],
@@ -986,7 +991,7 @@ def test_create_f16_table(mem_db: DBConnection):
    table.add(df)
    table.create_index(num_partitions=2, num_sub_vectors=2)

    query = df.vector.iloc[2]
    query = df["vector"][2].as_py()
    expected = table.search(query).limit(2).to_arrow()

    assert "s-2" in expected["text"].to_pylist()
@@ -1002,7 +1007,7 @@ def test_add_with_embedding_function(mem_db: DBConnection):
    table = mem_db.create_table("my_table", schema=MyTable)

    texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
    df = pd.DataFrame({"text": texts})
    df = pa.table({"text": texts})
    table.add(df)

    texts = ["the quick brown fox", "jumped over the lazy dog"]
@@ -1033,14 +1038,14 @@ def test_multiple_vector_columns(mem_db: DBConnection):
        {"vector1": v1, "vector2": v2, "text": "foo"},
        {"vector1": v2, "vector2": v1, "text": "bar"},
    ]
    df = pd.DataFrame(data)
    df = pa.Table.from_pylist(data)
    table.add(df)

    q = np.random.randn(10)
    result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
    result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
    result1 = table.search(q, vector_column_name="vector1").limit(1).to_arrow()
    result2 = table.search(q, vector_column_name="vector2").limit(1).to_arrow()

    assert result1["text"].iloc[0] != result2["text"].iloc[0]
    assert result1["text"][0] != result2["text"][0]


def test_create_scalar_index(mem_db: DBConnection):
@@ -1078,22 +1083,22 @@ def test_empty_query(mem_db: DBConnection):
        "my_table",
        data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
    )
    df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
    val = df.id.iloc[0]
    df = table.search().select(["id"]).where("text='bar'").limit(1).to_arrow()
    val = df["id"][0].as_py()
    assert val == 1

    table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)])
    df = table.search().select(["id"]).to_pandas()
    assert len(df) == 100
    df = table.search().select(["id"]).to_arrow()
    assert df.num_rows == 100
    # None is the same as default
    df = table.search().select(["id"]).limit(None).to_pandas()
    assert len(df) == 100
    df = table.search().select(["id"]).limit(None).to_arrow()
    assert df.num_rows == 100
    # an invalid limit is the same as None, which is the same as default
    df = table.search().select(["id"]).limit(-1).to_pandas()
    assert len(df) == 100
    df = table.search().select(["id"]).limit(-1).to_arrow()
    assert df.num_rows == 100
    # valid limit should work
    df = table.search().select(["id"]).limit(42).to_pandas()
    assert len(df) == 42
    df = table.search().select(["id"]).limit(42).to_arrow()
    assert df.num_rows == 42


def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
@@ -1112,14 +1117,14 @@ def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
        {"vector_col": v1, "text": "foo"},
        {"vector_col": v2, "text": "bar"},
    ]
    df = pd.DataFrame(data)
    df = pa.Table.from_pylist(data)
    table.add(df)

    q = np.random.randn(10)
    result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
    result2 = table.search(q).limit(1).to_pandas()
    result1 = table.search(q, vector_column_name="vector_col").limit(1).to_arrow()
    result2 = table.search(q).limit(1).to_arrow()

    assert result1["text"].iloc[0] == result2["text"].iloc[0]
    assert result1["text"][0].as_py() == result2["text"][0].as_py()


def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
@@ -1139,12 +1144,12 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
        {"vector1": v1, "vector2": v2, "text": "foo"},
        {"vector1": v2, "vector2": v1, "text": "bar"},
    ]
    df = pd.DataFrame(data)
    df = pa.Table.from_pylist(data)
    table.add(df)

    q = np.random.randn(10)
    with pytest.raises(ValueError):
        table.search(q).limit(1).to_pandas()
        table.search(q).limit(1).to_arrow()


def test_compact_cleanup(tmp_db: DBConnection):
@@ -652,6 +652,11 @@ impl HybridQuery {
        self.inner_vec.bypass_vector_index();
    }

    #[pyo3(signature = (lower_bound=None, upper_bound=None))]
    pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
        self.inner_vec.distance_range(lower_bound, upper_bound);
    }

    pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> {
        Ok(VectorQuery {
            inner: self.inner_vec.inner.clone(),
@@ -177,15 +177,19 @@ impl Table {
        })
    }

    #[pyo3(signature = (column, index=None, replace=None))]
    #[pyo3(signature = (column, index=None, replace=None, wait_timeout=None))]
    pub fn create_index<'a>(
        self_: PyRef<'a, Self>,
        column: String,
        index: Option<Bound<'_, PyAny>>,
        replace: Option<bool>,
        wait_timeout: Option<Bound<'_, PyAny>>,
    ) -> PyResult<Bound<'a, PyAny>> {
        let index = extract_index_params(&index)?;
        let mut op = self_.inner_ref()?.create_index(&[column], index);
        let timeout = wait_timeout.map(|t| t.extract::<std::time::Duration>().unwrap());
        let mut op = self_
            .inner_ref()?
            .create_index_with_timeout(&[column], index, timeout);
        if let Some(replace) = replace {
            op = op.replace(replace);
        }
@@ -204,6 +208,26 @@ impl Table {
        })
    }

    pub fn wait_for_index<'a>(
        self_: PyRef<'a, Self>,
        index_names: Vec<String>,
        timeout: Bound<'_, PyAny>,
    ) -> PyResult<Bound<'a, PyAny>> {
        let inner = self_.inner_ref()?.clone();
        let timeout = timeout.extract::<std::time::Duration>()?;
        future_into_py(self_.py(), async move {
            let index_refs = index_names
                .iter()
                .map(String::as_str)
                .collect::<Vec<&str>>();
            inner
                .wait_for_index(&index_refs, timeout)
                .await
                .infer_error()?;
            Ok(())
        })
    }

    pub fn prewarm_index(self_: PyRef<'_, Self>, index_name: String) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner_ref()?.clone();
        future_into_py(self_.py(), async move {
@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -35,6 +35,8 @@ pub enum Error {
    Schema { message: String },
    #[snafu(display("Runtime error: {message}"))]
    Runtime { message: String },
    #[snafu(display("Timeout error: {message}"))]
    Timeout { message: String },

    // 3rd party / external errors
    #[snafu(display("object_store error: {source}"))]
@@ -1,11 +1,11 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::sync::Arc;

use scalar::FtsIndexBuilder;
use serde::Deserialize;
use serde_with::skip_serializing_none;
use std::sync::Arc;
use std::time::Duration;
use vector::IvfFlatIndexBuilder;

use crate::{table::BaseTable, DistanceType, Error, Result};
@@ -17,6 +17,7 @@ use self::{

pub mod scalar;
pub mod vector;
pub mod waiter;

/// Supported index types.
#[derive(Debug, Clone)]
@@ -69,6 +70,7 @@ pub struct IndexBuilder {
    pub(crate) index: Index,
    pub(crate) columns: Vec<String>,
    pub(crate) replace: bool,
    pub(crate) wait_timeout: Option<Duration>,
}

impl IndexBuilder {
@@ -78,6 +80,7 @@ impl IndexBuilder {
            index,
            columns,
            replace: true,
            wait_timeout: None,
        }
    }

@@ -91,6 +94,15 @@ impl IndexBuilder {
        self
    }

    /// Duration of time to wait for asynchronous indexing to complete. If not set,
    /// `create_index()` will not wait.
    ///
    /// This is not supported for `NativeTable` since indexing is synchronous.
    pub fn wait_timeout(mut self, d: Duration) -> Self {
        self.wait_timeout = Some(d);
        self
    }

    pub async fn execute(self) -> Result<()> {
        self.parent.clone().create_index(self).await
    }
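On the Python side this builder option surfaces as a `wait_timeout` keyword argument, as exercised in the remote-table tests earlier in this diff. A minimal sketch, assuming `db` is a remote (LanceDB Cloud) connection, since local indexing is synchronous and does not need it:

    from datetime import timedelta

    table = db.create_table("test", [{"id": 1}])
    table.create_scalar_index("id", wait_timeout=timedelta(seconds=2))
    table.create_fts_index("text", wait_timeout=timedelta(seconds=2))
    table.create_index(vector_column_name="vector", wait_timeout=timedelta(seconds=10))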
rust/lancedb/src/index/waiter.rs (new file, 89 lines)
@@ -0,0 +1,89 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use crate::error::Result;
use crate::table::BaseTable;
use crate::Error;
use log::debug;
use std::time::{Duration, Instant};
use tokio::time::sleep;

const DEFAULT_SLEEP_MS: u64 = 1000;
const MAX_WAIT: Duration = Duration::from_secs(2 * 60 * 60);

/// Poll the table using list_indices() and index_stats() until all of the indices have 0 un-indexed rows.
/// Will return Error::Timeout if the columns are not fully indexed within the timeout.
pub async fn wait_for_index(
    table: &dyn BaseTable,
    index_names: &[&str],
    timeout: Duration,
) -> Result<()> {
    if timeout > MAX_WAIT {
        return Err(Error::InvalidInput {
            message: format!("timeout must be less than {:?}", MAX_WAIT),
        });
    }
    let start = Instant::now();
    let mut remaining = index_names.to_vec();

    // poll via list_indices() and index_stats() until all indices are created and fully indexed
    while start.elapsed() < timeout {
        let mut completed = vec![];
        let indices = table.list_indices().await?;

        for &idx in &remaining {
            if !indices.iter().any(|i| i.name == *idx) {
                debug!("still waiting for new index '{}'", idx);
                continue;
            }

            let stats = table.index_stats(idx.as_ref()).await?;
            match stats {
                None => {
                    debug!("still waiting for new index '{}'", idx);
                    continue;
                }
                Some(s) => {
                    if s.num_unindexed_rows == 0 {
                        // note: this may never stabilize under constant writes.
                        // we should later replace this with a status/job model
                        completed.push(idx);
                        debug!(
                            "fully indexed '{}'. indexed rows: {}",
                            idx, s.num_indexed_rows
                        );
                    } else {
                        debug!(
                            "still waiting for index '{}'. unindexed rows: {}",
                            idx, s.num_unindexed_rows
                        );
                    }
                }
            }
        }
        remaining.retain(|idx| !completed.contains(idx));
        if remaining.is_empty() {
            return Ok(());
        }
        sleep(Duration::from_millis(DEFAULT_SLEEP_MS)).await;
    }

    // debug log index diagnostics
    for &r in &remaining {
        let stats = table.index_stats(r.as_ref()).await?;
        match stats {
            Some(s) => debug!(
                "index '{}' not fully indexed after {:?}. stats: {:?}",
                r, timeout, s
            ),
            None => debug!("index '{}' not found after {:?}", r, timeout),
        }
    }

    Err(Error::Timeout {
        message: format!(
            "timed out waiting for indices: {:?} after {:?}",
            remaining, timeout
        ),
    })
}
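The same polling strategy can be reproduced from Python with the public index_stats() API if custom behavior is needed, e.g. a different poll interval. A hedged sketch, not the library's implementation; it assumes index_stats() returns None until the index exists and exposes num_unindexed_rows as in the tests above:

    import time
    from datetime import timedelta

    def wait_for_index(table, index_names, timeout=timedelta(minutes=5), poll_s=1.0):
        deadline = time.monotonic() + timeout.total_seconds()
        remaining = set(index_names)
        while time.monotonic() < deadline:
            for name in list(remaining):
                stats = table.index_stats(name)  # None until the index exists
                if stats is not None and stats.num_unindexed_rows == 0:
                    remaining.discard(name)
            if not remaining:
                return
            time.sleep(poll_s)
        raise TimeoutError(f"timed out waiting for indices: {sorted(remaining)}")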
@@ -8,6 +8,7 @@

pub(crate) mod client;
pub(crate) mod db;
mod retry;
pub(crate) mod table;
pub(crate) mod util;


@@ -1,17 +1,17 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};

use http::HeaderName;
use log::debug;
use reqwest::{
    header::{HeaderMap, HeaderValue},
    Request, RequestBuilder, Response,
    Body, Request, RequestBuilder, Response,
};
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};

use crate::error::{Error, Result};
use crate::remote::db::RemoteOptions;
use crate::remote::retry::{ResolvedRetryConfig, RetryCounter};

const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
@@ -118,41 +118,14 @@ pub struct RetryConfig {
    /// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable
    /// to set this value. Use a comma-separated list of integer values.
    ///
    /// The default is 429, 500, 502, 503.
    /// Note that write operations will never be retried on 5xx errors as this may
    /// result in duplicated writes.
    ///
    /// The default is 409, 429, 500, 502, 503, 504.
    pub statuses: Option<Vec<u16>>,
    // TODO: should we allow customizing methods?
}

#[derive(Debug, Clone)]
struct ResolvedRetryConfig {
    retries: u8,
    connect_retries: u8,
    read_retries: u8,
    backoff_factor: f32,
    backoff_jitter: f32,
    statuses: Vec<reqwest::StatusCode>,
}

impl TryFrom<RetryConfig> for ResolvedRetryConfig {
    type Error = Error;

    fn try_from(retry_config: RetryConfig) -> Result<Self> {
        Ok(Self {
            retries: retry_config.retries.unwrap_or(3),
            connect_retries: retry_config.connect_retries.unwrap_or(3),
            read_retries: retry_config.read_retries.unwrap_or(3),
            backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
            backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
            statuses: retry_config
                .statuses
                .unwrap_or_else(|| vec![429, 500, 502, 503])
                .into_iter()
                .map(|status| reqwest::StatusCode::from_u16(status).unwrap())
                .collect(),
        })
    }
}

// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
// we can mock responses in tests. Based on the patterns from this blog post:
// https://write.as/balrogboogie/testing-reqwest-based-clients
@@ -160,8 +133,8 @@ impl TryFrom<RetryConfig> for ResolvedRetryConfig {
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
    client: reqwest::Client,
    host: String,
    retry_config: ResolvedRetryConfig,
    sender: S,
    pub(crate) retry_config: ResolvedRetryConfig,
    pub(crate) sender: S,
}

pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
@@ -375,74 +348,69 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
        self.client.post(full_uri)
    }

    pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
    pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> {
        let (client, request) = req.build_split();
        let mut request = request.unwrap();
        let request_id = self.extract_request_id(&mut request);
        self.log_request(&request, &request_id);

        // Set a request id.
        // TODO: allow the user to supply this, through middleware?
        let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
            request_id.to_str().unwrap().to_string()
        } else {
            let request_id = uuid::Uuid::new_v4().to_string();
            let header = HeaderValue::from_str(&request_id).unwrap();
            request.headers_mut().insert(REQUEST_ID_HEADER, header);
            request_id
        };

        if log::log_enabled!(log::Level::Debug) {
            let content_type = request
                .headers()
                .get("content-type")
                .map(|v| v.to_str().unwrap());
            if content_type == Some("application/json") {
                let body = request.body().as_ref().unwrap().as_bytes().unwrap();
                let body = String::from_utf8_lossy(body);
                debug!(
                    "Sending request_id={}: {:?} with body {}",
                    request_id, request, body
                );
            } else {
                debug!("Sending request_id={}: {:?}", request_id, request);
            }
        }

        if with_retry {
            self.send_with_retry_impl(client, request, request_id).await
        } else {
            let response = self
                .sender
                .send(&client, request)
                .await
                .err_to_http(request_id.clone())?;
            debug!(
                "Received response for request_id={}: {:?}",
                request_id, &response
            );
            Ok((request_id, response))
        }
        let response = self
            .sender
            .send(&client, request)
            .await
            .err_to_http(request_id.clone())?;
        debug!(
            "Received response for request_id={}: {:?}",
            request_id, &response
        );
        Ok((request_id, response))
    }

    async fn send_with_retry_impl(
    /// Send the request using retries configured in the RetryConfig.
    /// If retry_5xx is false, 5xx requests will not be retried regardless of the statuses configured
    /// in the RetryConfig.
    /// Since this requires arrow serialization, this is implemented here instead of in RestfulLanceDbClient
    pub async fn send_with_retry(
        &self,
        client: reqwest::Client,
        req: Request,
        request_id: String,
        req_builder: RequestBuilder,
        mut make_body: Option<Box<dyn FnMut() -> Result<Body> + Send + 'static>>,
        retry_5xx: bool,
    ) -> Result<(String, Response)> {
        let mut retry_counter = RetryCounter::new(&self.retry_config, request_id);
        let retry_config = &self.retry_config;
        let non_5xx_statuses = retry_config
            .statuses
            .iter()
            .filter(|s| !s.is_server_error())
            .cloned()
            .collect::<Vec<_>>();

        // clone and build the request to extract the request id
        let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime {
            message: "Attempted to retry a request that cannot be cloned".to_string(),
        })?;
        let (_, r) = tmp_req.build_split();
        let mut r = r.unwrap();
        let request_id = self.extract_request_id(&mut r);
        let mut retry_counter = RetryCounter::new(retry_config, request_id.clone());

        loop {
            // This only works if the request body is not a stream. If it is
            // a stream, we can't use the retry path. We would need to implement
            // an outer retry.
            let request = req.try_clone().ok_or_else(|| Error::Runtime {
            let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime {
                message: "Attempted to retry a request that cannot be cloned".to_string(),
            })?;
            let response = self
                .sender
                .send(&client, request)
                .await
                .map(|r| (r.status(), r));

            // set the streaming body on the request builder after clone
            if let Some(body_gen) = make_body.as_mut() {
                let body = body_gen()?;
                req_builder = req_builder.body(body);
            }

            let (c, request) = req_builder.build_split();
            let mut request = request.unwrap();
            self.set_request_id(&mut request, &request_id.clone());
            self.log_request(&request, &request_id);

            let response = self.sender.send(&c, request).await.map(|r| (r.status(), r));

            match response {
                Ok((status, response)) if status.is_success() => {
                    debug!(
@@ -451,7 +419,10 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
                    );
                    return Ok((retry_counter.request_id, response));
                }
                Ok((status, response)) if self.retry_config.statuses.contains(&status) => {
                Ok((status, response))
                    if (retry_5xx && retry_config.statuses.contains(&status))
                        || non_5xx_statuses.contains(&status) =>
                {
                    let source = self
                        .check_response(&retry_counter.request_id, response)
                        .await
@@ -480,6 +451,47 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
        }
    }

    fn log_request(&self, request: &Request, request_id: &String) {
        if log::log_enabled!(log::Level::Debug) {
            let content_type = request
                .headers()
                .get("content-type")
                .map(|v| v.to_str().unwrap());
            if content_type == Some("application/json") {
                let body = request.body().as_ref().unwrap().as_bytes().unwrap();
                let body = String::from_utf8_lossy(body);
                debug!(
                    "Sending request_id={}: {:?} with body {}",
                    request_id, request, body
                );
            } else {
                debug!("Sending request_id={}: {:?}", request_id, request);
            }
        }
    }

    /// Extract the request ID from the request headers.
    /// If the request ID header is not set, this will generate a new one and set
    /// it on the request headers
    pub fn extract_request_id(&self, request: &mut Request) -> String {
        // Set a request id.
        // TODO: allow the user to supply this, through middleware?
        let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
            request_id.to_str().unwrap().to_string()
        } else {
            let request_id = uuid::Uuid::new_v4().to_string();
            self.set_request_id(request, &request_id);
            request_id
        };
        request_id
    }

    /// Set the request ID header
    pub fn set_request_id(&self, request: &mut Request, request_id: &str) {
        let header = HeaderValue::from_str(request_id).unwrap();
        request.headers_mut().insert(REQUEST_ID_HEADER, header);
    }

    pub async fn check_response(&self, request_id: &str, response: Response) -> Result<Response> {
        // Try to get the response text, but if that fails, just return the status code
        let status = response.status();
@@ -501,91 +513,6 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
    }
}

struct RetryCounter<'a> {
    request_failures: u8,
    connect_failures: u8,
    read_failures: u8,
    config: &'a ResolvedRetryConfig,
    request_id: String,
}

impl<'a> RetryCounter<'a> {
    fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
        Self {
            request_failures: 0,
            connect_failures: 0,
            read_failures: 0,
            config,
            request_id,
        }
    }

    fn check_out_of_retries(
        &self,
        source: Box<dyn std::error::Error + Send + Sync>,
        status_code: Option<reqwest::StatusCode>,
    ) -> Result<()> {
        if self.request_failures >= self.config.retries
            || self.connect_failures >= self.config.connect_retries
            || self.read_failures >= self.config.read_retries
        {
            Err(Error::Retry {
                request_id: self.request_id.clone(),
                request_failures: self.request_failures,
                max_request_failures: self.config.retries,
                connect_failures: self.connect_failures,
                max_connect_failures: self.config.connect_retries,
                read_failures: self.read_failures,
                max_read_failures: self.config.read_retries,
                source,
                status_code,
            })
        } else {
            Ok(())
        }
    }

    fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> {
        self.request_failures += 1;
        let status_code = if let crate::Error::Http { status_code, .. } = &source {
            *status_code
        } else {
            None
        };
        self.check_out_of_retries(Box::new(source), status_code)
    }

    fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> {
        self.connect_failures += 1;
        let status_code = source.status();
        self.check_out_of_retries(Box::new(source), status_code)
    }

    fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> {
        self.read_failures += 1;
        let status_code = source.status();
        self.check_out_of_retries(Box::new(source), status_code)
    }

    fn next_sleep_time(&self) -> Duration {
        let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
        let jitter = rand::random::<f32>() * self.config.backoff_jitter;
        let sleep_time = Duration::from_secs_f32(backoff + jitter);
        debug!(
            "Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
            self.request_id,
            self.connect_failures,
            self.config.connect_retries,
            self.request_failures,
            self.config.retries,
            self.read_failures,
            self.config.read_retries,
            sleep_time
        );
        sleep_time
    }
}

pub trait RequestResultExt {
    type Output;
    fn err_to_http(self, request_id: String) -> Result<Self::Output>;
@@ -255,7 +255,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
        if let Some(start_after) = request.start_after {
            req = req.query(&[("page_token", start_after)]);
        }
        let (request_id, rsp) = self.client.send(req, true).await?;
        let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
        let rsp = self.client.check_response(&request_id, rsp).await?;
        let version = parse_server_version(&request_id, &rsp)?;
        let tables = rsp
@@ -302,7 +302,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
            .body(data_buffer)
            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);

        let (request_id, rsp) = self.client.send(req, false).await?;
        let (request_id, rsp) = self.client.send(req).await?;

        if rsp.status() == StatusCode::BAD_REQUEST {
            let body = rsp.text().await.err_to_http(request_id.clone())?;
@@ -362,7 +362,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
        let req = self
            .client
            .post(&format!("/v1/table/{}/describe/", request.name));
        let (request_id, rsp) = self.client.send(req, true).await?;
        let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
        if rsp.status() == StatusCode::NOT_FOUND {
            return Err(crate::Error::TableNotFound { name: request.name });
        }
@@ -383,7 +383,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
            .client
            .post(&format!("/v1/table/{}/rename/", current_name));
        let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
        let (request_id, resp) = self.client.send(req, false).await?;
        let (request_id, resp) = self.client.send(req).await?;
        self.client.check_response(&request_id, resp).await?;
        let table = self.table_cache.remove(current_name).await;
        if let Some(table) = table {
@@ -394,7 +394,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {

    async fn drop_table(&self, name: &str) -> Result<()> {
        let req = self.client.post(&format!("/v1/table/{}/drop/", name));
        let (request_id, resp) = self.client.send(req, true).await?;
        let (request_id, resp) = self.client.send(req).await?;
        self.client.check_response(&request_id, resp).await?;
        self.table_cache.remove(name).await;
        Ok(())
rust/lancedb/src/remote/retry.rs (new file, 122 lines)
@@ -0,0 +1,122 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use crate::remote::RetryConfig;
use crate::Error;
use log::debug;
use std::time::Duration;

pub struct RetryCounter<'a> {
    pub request_failures: u8,
    pub connect_failures: u8,
    pub read_failures: u8,
    pub config: &'a ResolvedRetryConfig,
    pub request_id: String,
}

impl<'a> RetryCounter<'a> {
    pub(crate) fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
        Self {
            request_failures: 0,
            connect_failures: 0,
            read_failures: 0,
            config,
            request_id,
        }
    }

    fn check_out_of_retries(
        &self,
        source: Box<dyn std::error::Error + Send + Sync>,
        status_code: Option<reqwest::StatusCode>,
    ) -> crate::Result<()> {
        if self.request_failures >= self.config.retries
            || self.connect_failures >= self.config.connect_retries
            || self.read_failures >= self.config.read_retries
        {
            Err(Error::Retry {
                request_id: self.request_id.clone(),
                request_failures: self.request_failures,
                max_request_failures: self.config.retries,
                connect_failures: self.connect_failures,
                max_connect_failures: self.config.connect_retries,
                read_failures: self.read_failures,
                max_read_failures: self.config.read_retries,
                source,
                status_code,
            })
        } else {
            Ok(())
        }
    }

    pub fn increment_request_failures(&mut self, source: crate::Error) -> crate::Result<()> {
        self.request_failures += 1;
        let status_code = if let crate::Error::Http { status_code, .. } = &source {
            *status_code
        } else {
            None
        };
        self.check_out_of_retries(Box::new(source), status_code)
    }

    pub fn increment_connect_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
        self.connect_failures += 1;
        let status_code = source.status();
        self.check_out_of_retries(Box::new(source), status_code)
    }

    pub fn increment_read_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
        self.read_failures += 1;
        let status_code = source.status();
        self.check_out_of_retries(Box::new(source), status_code)
    }

    pub fn next_sleep_time(&self) -> Duration {
        let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
        let jitter = rand::random::<f32>() * self.config.backoff_jitter;
        let sleep_time = Duration::from_secs_f32(backoff + jitter);
        debug!(
            "Retrying request {:?} ({}/{} connect, {}/{} request, {}/{} read) in {:?}",
            self.request_id,
            self.connect_failures,
            self.config.connect_retries,
            self.request_failures,
            self.config.retries,
            self.read_failures,
            self.config.read_retries,
            sleep_time
        );
        sleep_time
    }
}

#[derive(Debug, Clone)]
pub struct ResolvedRetryConfig {
    pub retries: u8,
    pub connect_retries: u8,
    pub read_retries: u8,
    pub backoff_factor: f32,
    pub backoff_jitter: f32,
    pub statuses: Vec<reqwest::StatusCode>,
}

impl TryFrom<RetryConfig> for ResolvedRetryConfig {
    type Error = Error;

    fn try_from(retry_config: RetryConfig) -> crate::Result<Self> {
        Ok(Self {
            retries: retry_config.retries.unwrap_or(3),
            connect_retries: retry_config.connect_retries.unwrap_or(3),
            read_retries: retry_config.read_retries.unwrap_or(3),
            backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
            backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
            statuses: retry_config
                .statuses
                .unwrap_or_else(|| vec![409, 429, 500, 502, 503, 504])
                .into_iter()
                .map(|status| reqwest::StatusCode::from_u16(status).unwrap())
                .collect(),
        })
    }
}
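The backoff schedule implemented above is backoff_factor * 2^request_failures seconds plus uniform jitter in [0, backoff_jitter). A quick worked example in Python with the default factor of 0.25 and jitter of 0.25:

    import random

    def next_sleep_time(request_failures, backoff_factor=0.25, backoff_jitter=0.25):
        backoff = backoff_factor * (2.0 ** request_failures)
        jitter = random.random() * backoff_jitter
        return backoff + jitter

    # Expected base delays for failures 1..3: 0.5s, 1.0s, 2.0s (plus up to 0.25s jitter)
    for n in range(1, 4):
        print(f"retry {n}: ~{next_sleep_time(n):.2f}s")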
@@ -1,17 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error, Table};
use arrow_array::RecordBatchReader;
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
@@ -25,9 +21,19 @@ use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
use lance_datafusion::exec::{execute_plan, OneShotExec};
use reqwest::{RequestBuilder, Response};
use serde::{Deserialize, Serialize};
use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::RwLock;

use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
use crate::index::waiter::wait_for_index;
use crate::{
    connection::NoData,
    error::Result,
@@ -39,11 +45,6 @@ use crate::{
    },
};

use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;

const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");

#[derive(Debug)]
@@ -83,7 +84,7 @@ impl<S: HttpSend> RemoteTable<S> {
            let body = serde_json::json!({ "version": version });
            request = request.json(&body);

            let (request_id, response) = self.client.send(request, true).await?;
            let (request_id, response) = self.send(request, true).await?;

            let response = self.check_table_response(&request_id, response).await?;

@@ -127,6 +128,61 @@ impl<S: HttpSend> RemoteTable<S> {
        Ok(reqwest::Body::wrap_stream(body_stream))
    }

    /// Buffer the reader into memory
    async fn buffer_reader<R: RecordBatchReader + ?Sized>(
        reader: &mut R,
    ) -> Result<(SchemaRef, Vec<RecordBatch>)> {
        let schema = reader.schema();
        let mut batches = Vec::new();
        for batch in reader {
            batches.push(batch?);
        }
        Ok((schema, batches))
    }

    /// Create a new RecordBatchReader from buffered data
    fn make_reader(schema: SchemaRef, batches: Vec<RecordBatch>) -> impl RecordBatchReader {
        let iter = batches.into_iter().map(Ok);
        RecordBatchIterator::new(iter, schema)
    }

    async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
        let res = if with_retry {
            self.client.send_with_retry(req, None, true).await?
        } else {
            self.client.send(req).await?
        };
        Ok(res)
    }

    /// Send the request with streaming body.
    /// This will use retries if with_retry is set and the number of configured retries is > 0.
    /// If retries are enabled, the stream will be buffered into memory.
    async fn send_streaming(
        &self,
        req: RequestBuilder,
        mut data: Box<dyn RecordBatchReader + Send>,
        with_retry: bool,
    ) -> Result<(String, Response)> {
        if !with_retry || self.client.retry_config.retries == 0 {
            let body = Self::reader_as_body(data)?;
            return self.client.send(req.body(body)).await;
        }

        // to support retries, buffer into memory and clone the batches on each retry
        let (schema, batches) = Self::buffer_reader(&mut *data).await?;
        let make_body = Box::new(move || {
            let reader = Self::make_reader(schema.clone(), batches.clone());
            Self::reader_as_body(Box::new(reader))
        });
        let res = self
            .client
            .send_with_retry(req, Some(make_body), false)
            .await?;

        Ok(res)
    }

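The buffer-and-replay idea behind send_streaming can be sketched in Python terms: materialize the Arrow stream once, then rebuild a fresh reader for each retry attempt. The helper names here are illustrative, not the library's API:

    import pyarrow as pa

    def buffer_reader(reader: pa.RecordBatchReader):
        # Drain the stream once; batches stay in memory for replay
        schema = reader.schema
        batches = list(reader)
        return schema, batches

    def make_reader(schema, batches) -> pa.RecordBatchReader:
        # Each retry attempt gets a brand-new reader over the same batches
        return pa.RecordBatchReader.from_batches(schema, batches)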
    async fn check_table_response(
        &self,
        request_id: &str,
@@ -168,7 +224,8 @@ impl<S: HttpSend> RemoteTable<S> {
        }

        // Server requires k.
        let limit = params.limit.unwrap_or(usize::MAX);
        // use isize::MAX as usize to avoid overflow: https://github.com/lancedb/lancedb/issues/2211
        let limit = params.limit.unwrap_or(isize::MAX as usize);
        body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));

        if let Some(filter) = &params.filter {
@@ -339,8 +396,6 @@ impl<S: HttpSend> RemoteTable<S> {
        let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name));

        if let Some(timeout) = options.timeout {
            // Client side timeout
            request = request.timeout(timeout);
            // Also send to server, so it can abort the query if it takes too long.
            // (If it doesn't fit into u64, it's not worth sending anyways.)
            if let Ok(timeout_ms) = u64::try_from(timeout.as_millis()) {
@@ -355,11 +410,29 @@ impl<S: HttpSend> RemoteTable<S> {
            .collect();

        let futures = requests.into_iter().map(|req| async move {
            let (request_id, response) = self.client.send(req, true).await?;
            let (request_id, response) = self.send(req, true).await?;
            self.read_arrow_stream(&request_id, response).await
        });
        let streams = futures::future::try_join_all(futures).await?;
        Ok(streams)
        let streams = futures::future::try_join_all(futures);

        if let Some(timeout) = options.timeout {
            let timeout_future = tokio::time::sleep(timeout);
            tokio::pin!(timeout_future);
            tokio::pin!(streams);
            tokio::select! {
                _ = &mut timeout_future => {
                    Err(Error::Other {
                        message: format!("Query timeout after {} ms", timeout.as_millis()),
                        source: None,
                    })
                }
                result = &mut streams => {
                    Ok(result?)
                }
            }
        } else {
            Ok(streams.await?)
        }
    }

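The tokio::select! above races the joined query futures against a sleep. The asyncio analog of this client-side timeout looks roughly like the following; this is illustrative, not the Python binding's code:

    import asyncio

    async def query_with_timeout(run_queries, timeout_s: float):
        try:
            # Cancels the in-flight queries if the deadline passes first
            return await asyncio.wait_for(run_queries(), timeout=timeout_s)
        except asyncio.TimeoutError:
            raise RuntimeError(f"Query timeout after {int(timeout_s * 1000)} ms")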
    async fn prepare_query_bodies(&self, query: &AnyQuery) -> Result<Vec<serde_json::Value>> {
@@ -455,7 +528,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
let body = serde_json::json!({ "version": version });
|
||||
request = request.json(&body);
|
||||
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
self.checkout_latest().await?;
|
||||
Ok(())
|
||||
@@ -465,7 +538,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/version/list/", self.name));
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -511,7 +584,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
request = request.json(&body);
|
||||
}
|
||||
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
@@ -529,12 +602,10 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
data: Box<dyn RecordBatchReader + Send>,
|
||||
) -> Result<()> {
|
||||
self.check_mutable().await?;
|
||||
let body = Self::reader_as_body(data)?;
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/insert/", self.name))
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.body(body);
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
|
||||
match add.mode {
|
||||
AddDataMode::Append => {}
|
||||
@@ -543,8 +614,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
}
|
||||
}
|
||||
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
|
||||
let (request_id, response) = self.send_streaming(request, data, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
Ok(())
|
||||
@@ -612,7 +682,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let futures = requests.into_iter().map(|req| async move {
|
||||
let (request_id, response) = self.client.send(req, true).await?;
|
||||
let (request_id, response) = self.send(req, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
|
||||
@@ -654,7 +724,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
.collect();
|
||||
|
||||
let futures = requests.into_iter().map(|req| async move {
|
||||
let (request_id, response) = self.client.send(req, true).await?;
|
||||
let (request_id, response) = self.send(req, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
|
||||
@@ -696,7 +766,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
"predicate": update.filter,
|
||||
}));
|
||||
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
@@ -710,7 +780,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/delete/", self.name))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
@@ -796,34 +866,45 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {

let request = request.json(&body);

let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;

self.check_table_response(&request_id, response).await?;

if let Some(wait_timeout) = index.wait_timeout {
let name = format!("{}_idx", column);
self.wait_for_index(&[&name], wait_timeout).await?;
}

Ok(())
}

/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(&self, index_names: &[&str], timeout: Duration) -> Result<()> {
wait_for_index(self, index_names, timeout).await
}

async fn merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
self.check_mutable().await?;

let query = MergeInsertRequest::try_from(params)?;
let body = Self::reader_as_body(new_data)?;
let request = self
.client
.post(&format!("/v1/table/{}/merge_insert/", self.name))
.query(&query)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);

let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send_streaming(request, new_data, true).await?;

self.check_table_response(&request_id, response).await?;

Ok(())
}

async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
self.check_mutable().await?;
Err(Error::NotSupported {
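`wait_for_index` delegates to a shared `wait_for_index` helper from `crate::index::waiter` (imported in the table module diff further down) whose body is not included in this diff. Judging from the tests below, it polls index stats until every named index exists with zero unindexed rows, or fails with the timeout message the tests assert. A rough sketch of such a loop, with a hypothetical `index_stats` lookup standing in for the real stats call:

use std::time::{Duration, Instant};

struct IndexStats {
    num_unindexed_rows: usize,
}

// Stand-in for the table's per-index stats endpoint; `None` means the
// index does not exist yet.
async fn index_stats(_name: &str) -> Option<IndexStats> {
    Some(IndexStats { num_unindexed_rows: 0 })
}

async fn wait_for_index(index_names: &[&str], timeout: Duration) -> Result<(), String> {
    let start = Instant::now();
    loop {
        // Done only when every named index exists and is fully built.
        let mut all_done = true;
        for name in index_names {
            match index_stats(name).await {
                Some(stats) if stats.num_unindexed_rows == 0 => {}
                _ => {
                    all_done = false;
                    break;
                }
            }
        }
        if all_done {
            return Ok(());
        }
        if start.elapsed() >= timeout {
            return Err(format!(
                "Timeout error: timed out waiting for indices: {:?} after {}s",
                index_names,
                timeout.as_secs()
            ));
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
    }
}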
@@ -852,7 +933,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/add_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?; // todo:
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -891,7 +972,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/alter_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -903,7 +984,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/drop_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -917,7 +998,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);

let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;

#[derive(Deserialize)]
@@ -974,7 +1055,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);

let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;

if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
@@ -998,7 +1079,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
"/v1/table/{}/index/{}/drop/",
self.name, index_name
));
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -1459,6 +1540,42 @@ mod tests {
assert_eq!(&body, &expected_body);
}

#[tokio::test]
async fn test_merge_insert_retries_on_409() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let data = Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
));

// Default parameters
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/");

let params = request.url().query_pairs().collect::<HashMap<_, _>>();
assert_eq!(params["on"], "some_col");
assert_eq!(params["when_matched_update_all"], "false");
assert_eq!(params["when_not_matched_insert_all"], "false");
assert_eq!(params["when_not_matched_by_source_delete"], "false");
assert!(!params.contains_key("when_matched_update_all_filt"));
assert!(!params.contains_key("when_not_matched_by_source_delete_filt"));

http::Response::builder().status(409).body("").unwrap()
});

let e = table
.merge_insert(&["some_col"])
.execute(data)
.await
.unwrap_err();
assert!(e.to_string().contains("Hit retry limit"));
}

#[tokio::test]
async fn test_delete() {
let table = Table::new_with_handler("my_table", |request| {
@@ -1500,7 +1617,7 @@ mod tests {
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let expected_body = serde_json::json!({
"k": usize::MAX,
"k": isize::MAX as usize,
"prefilter": true,
"vector": [], // Empty vector means no vector query.
"version": null,
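The expected `k` here drops from `usize::MAX` to `isize::MAX as usize`. The diff does not state the motivation, but a likely one is JSON interoperability: on 64-bit targets `usize::MAX` equals `u64::MAX`, which does not fit in a signed 64-bit integer, while `isize::MAX` does. A small self-contained demonstration of the boundary:

// Shows why clamping to isize::MAX keeps `k` representable as a
// signed 64-bit JSON integer on 64-bit targets.
fn main() {
    // u64::MAX cannot be converted to an i64...
    assert!(i64::try_from(u64::MAX).is_err());
    // ...but isize::MAX (== i64::MAX on 64-bit targets) round-trips fine.
    assert_eq!(i64::try_from(isize::MAX as u64).unwrap(), i64::MAX);

    // serde_json will happily emit u64::MAX, but an i64-based reader
    // cannot recover it as a signed integer.
    let v = serde_json::json!({ "k": u64::MAX });
    assert_eq!(v["k"].as_i64(), None);
    assert_eq!(v["k"].as_u64(), Some(u64::MAX));
}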
@@ -2416,4 +2533,88 @@ mod tests {
});
table.drop_index("my_index").await.unwrap();
}

#[tokio::test]
async fn test_wait_for_index() {
let table = _make_table_with_indices(0);
table
.wait_for_index(&["vector_idx", "my_idx"], Duration::from_secs(1))
.await
.unwrap();
}

#[tokio::test]
async fn test_wait_for_index_timeout() {
let table = _make_table_with_indices(100);
let e = table
.wait_for_index(&["vector_idx", "my_idx"], Duration::from_secs(1))
.await
.unwrap_err();
assert_eq!(
e.to_string(),
"Timeout error: timed out waiting for indices: [\"vector_idx\", \"my_idx\"] after 1s"
);
}

#[tokio::test]
async fn test_wait_for_index_timeout_never_created() {
let table = _make_table_with_indices(0);
let e = table
.wait_for_index(&["doesnt_exist_idx"], Duration::from_secs(1))
.await
.unwrap_err();
assert_eq!(
e.to_string(),
"Timeout error: timed out waiting for indices: [\"doesnt_exist_idx\"] after 1s"
);
}

fn _make_table_with_indices(unindexed_rows: usize) -> Table {
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");

let response_body = match request.url().path() {
"/v1/table/my_table/index/list/" => {
serde_json::json!({
"indexes": [
{
"index_name": "vector_idx",
"index_uuid": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
"columns": ["vector"],
"index_status": "done",
},
{
"index_name": "my_idx",
"index_uuid": "34255f64-5717-4562-b3fc-2c963f66afa6",
"columns": ["my_column"],
"index_status": "done",
},
]
})
}
"/v1/table/my_table/index/vector_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": unindexed_rows,
"index_type": "IVF_PQ",
"distance_type": "l2"
})
}
"/v1/table/my_table/index/my_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": unindexed_rows,
"index_type": "LABEL_LIST"
})
}
_path => {
serde_json::json!(None::<String>)
}
};
let body = serde_json::to_string(&response_body).unwrap();
let status = if body == "null" { 404 } else { 200 };
http::Response::builder().status(status).body(body).unwrap()
});
table
}
}
@@ -3,10 +3,6 @@

//! LanceDB Table APIs

use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
@@ -45,6 +41,10 @@ use lance_table::format::Manifest;
use lance_table::io::commit::ManifestNamingScheme;
use log::info;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::format;
use std::path::Path;
use std::sync::Arc;

use crate::arrow::IntoArrow;
use crate::connection::NoData;
@@ -78,6 +78,7 @@ pub mod datafusion;
pub(crate) mod dataset;
pub mod merge;

use crate::index::waiter::wait_for_index;
pub use chrono::Duration;
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::scanner::DatasetRecordBatchStream;
@@ -491,6 +492,13 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn table_definition(&self) -> Result<TableDefinition>;
/// Get the table URI
fn dataset_uri(&self) -> &str;
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()>;
}

/// A Table is a collection of strongly typed Rows.
@@ -769,6 +777,28 @@ impl Table {
)
}

/// See [Table::create_index]
/// For remote tables, this allows an optional wait_timeout to poll until asynchronous indexing is complete
pub fn create_index_with_timeout(
&self,
columns: &[impl AsRef<str>],
index: Index,
wait_timeout: Option<std::time::Duration>,
) -> IndexBuilder {
let mut builder = IndexBuilder::new(
self.inner.clone(),
columns
.iter()
.map(|val| val.as_ref().to_string())
.collect::<Vec<_>>(),
index,
);
if let Some(timeout) = wait_timeout {
builder = builder.wait_timeout(timeout);
}
builder
}

/// Create a builder for a merge insert operation
///
/// This operation can add rows, update rows, and remove rows all in a single
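Combined with the remote `create_index` change above, this gives callers a single call that creates an index and polls until it is ready. A hypothetical usage sketch; the table is assumed to be backed by a remote (LanceDB Cloud) database where index builds are asynchronous, and `Index::Auto` stands in for whatever index spec the caller wants:

use std::time::Duration;

use lancedb::index::Index;

// Hypothetical helper: create the index and block (up to a bound)
// until the remote build finishes.
async fn build_index(table: &lancedb::Table) -> lancedb::Result<()> {
    table
        .create_index_with_timeout(
            &["vector"],
            Index::Auto,
            // Poll for up to two minutes before returning Error::Timeout.
            Some(Duration::from_secs(120)),
        )
        .execute()
        .await
}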
@@ -1104,6 +1134,16 @@ impl Table {
self.inner.prewarm_index(name).await
}

/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
pub async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()> {
self.inner.wait_for_index(index_names, timeout).await
}

// Take many execution plans and map them into a single plan that adds
// a query_index column and unions them.
pub(crate) fn multi_vector_plan(
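Since `wait_for_index` reports the not-ready case as a timeout error, callers can treat it as retryable rather than fatal. A hedged sketch of that pattern; the exact shape of `Error::Timeout` is not reproduced in this diff, so the match falls back to string inspection:

use std::time::Duration;

// Hypothetical caller: keep working while the index build finishes,
// treating a timeout as "not ready yet" rather than a hard failure.
async fn ensure_indexed(table: &lancedb::Table) -> lancedb::Result<bool> {
    match table
        .wait_for_index(&["vector_idx"], Duration::from_secs(30))
        .await
    {
        Ok(()) => Ok(true),
        // The tests above format this variant as "Timeout error: ...".
        Err(e) if e.to_string().starts_with("Timeout error") => Ok(false),
        Err(e) => Err(e),
    }
}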
@@ -2430,6 +2470,16 @@ impl BaseTable for NativeTable {
loss,
}))
}

/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()> {
wait_for_index(self, index_names, timeout).await
}
}

#[cfg(test)]
@@ -3213,7 +3263,10 @@ mod tests {
.execute()
.await
.unwrap();

table
.wait_for_index(&["embeddings_idx"], Duration::from_millis(10))
.await
.unwrap();
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();
@@ -3281,7 +3334,10 @@ mod tests {
.execute()
.await
.unwrap();

table
.wait_for_index(&["i_idx"], Duration::from_millis(10))
.await
.unwrap();
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();