Compare commits

...

25 Commits

Author SHA1 Message Date
Lance Release
d11819c90c Bump version: 0.22.0-beta.10 → 0.22.0-beta.11 2025-04-25 05:01:57 +00:00
BubbleCal
9b902272f1 fix: sync hybrid search ignores the distance range params (#2356)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added support for distance range filtering in hybrid vector queries,
allowing users to specify lower and upper bounds for search results.

- **Tests**
- Introduced new tests to validate distance range filtering and
reranking in both synchronous and asynchronous hybrid query scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-04-25 13:01:22 +08:00
Will Jones
8c0622fa2c fix: remote limit to avoid "Limit must be non-negative" (#2354)
To workaround this issue: https://github.com/lancedb/lancedb/issues/2211

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
- Improved handling of large query parameters to prevent potential
overflow issues when using the "k" parameter in queries.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-24 15:04:06 -07:00
Philip Meier
2191f948c3 fix: add missing pydantic model config compat (#2316)
Fixes #2315.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Refactor**
- Enhanced query processing to maintain smooth functionality across
different dependency versions, ensuring improved stability and
performance.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-22 14:46:10 -07:00
Will Jones
acc3b03004 ci: fix docs deploy (#2351)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Chores**
- Improved CI workflow for documentation builds by optimizing Rust build
settings and updating the runner environment.
  - Fixed a typo in a workflow step name.
- Streamlined caching steps to reduce redundancy and improve efficiency.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-22 13:55:34 -07:00
Lance Release
7f091b8c8e Updating package-lock.json 2025-04-22 19:16:43 +00:00
Lance Release
c19bdd9a24 Updating package-lock.json 2025-04-22 18:24:16 +00:00
Lance Release
dad0ff5cd2 Updating package-lock.json 2025-04-22 18:23:59 +00:00
Lance Release
a705621067 Bump version: 0.19.0-beta.9 → 0.19.0-beta.10 2025-04-22 18:23:39 +00:00
Lance Release
39614fdb7d Bump version: 0.22.0-beta.9 → 0.22.0-beta.10 2025-04-22 18:23:17 +00:00
Ryan Green
96d534d4bc feat: add retries to remote client for requests with stream bodies (#2349)
Closes https://github.com/lancedb/lancedb/issues/2307
* Adds retries to remote operations with stream bodies (add,
merge_insert)
* Change default retryable status codes to 409, 429, 500, 502, 503, 504
* Don't retry add or merge_insert operations on 5xx responses

Notes:
* Supporting retries on stream bodies means we have to buffer the body
into memory so it can be cloned on retry. This will impact memory use
patterns for the remote client. This buffering can be disabled by
disabling retries (i.e. setting retries to 0 in RetryConfig)
* It does not seem that retry config can be specified by env vars as the
documentation suggests. I added a follow-up issue
[here](https://github.com/lancedb/lancedb/issues/2350)



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Enhanced retry support for remote requests with configurable limits
and exponential backoff with jitter.
- Added robust retry logic for streaming data uploads, enabling retries
with buffered data to ensure reliability.

- **Bug Fixes**
- Improved error handling and retry behavior for HTTP status codes 409
and 504.

- **Refactor**
- Centralized and modularized HTTP request sending and retry logic
across remote database and table operations.
  - Streamlined request ID management for improved traceability.
- Simplified error message construction in index waiting functionality.

- **Tests**
  - Added a test verifying merge-insert retries on HTTP 409 responses.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-22 15:40:44 -02:30
Lance Release
5051d30d09 Updating package-lock.json 2025-04-21 23:55:43 +00:00
Lance Release
db853c4041 Updating package-lock.json 2025-04-21 22:50:56 +00:00
Lance Release
76d1d22bdc Updating package-lock.json 2025-04-21 22:50:40 +00:00
Lance Release
d8746c61c6 Bump version: 0.19.0-beta.8 → 0.19.0-beta.9 2025-04-21 22:50:20 +00:00
Lance Release
1a66df2627 Bump version: 0.22.0-beta.8 → 0.22.0-beta.9 2025-04-21 22:49:59 +00:00
Will Jones
44670076c1 fix: move timeout to avoid retries (#2347)
I added a timeout to query execution options in
https://github.com/lancedb/lancedb/pull/2288. However, this was send to
the request timeout, but the retry implementation is unaware of this
timeout. So once the query timed out, a retry would be triggered.
Instead, this PR changes it so the timeout happens outside the retry
loop.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Bug Fixes**
- Improved query timeout handling to provide clearer error messages and
more reliable cancellation if a query takes too long to complete.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 14:27:04 -07:00
Will Jones
92f0b16e46 fix(python): make sure pandas is optional (#2346)
Fixes #2344


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Tests**
- Updated tests to use PyArrow Tables instead of pandas DataFrames where
possible, reducing reliance on pandas.
- Tests that require pandas are now automatically skipped if pandas is
not installed.
- **Chores**
- Improved workflow to uninstall both pylance and pandas in a specific
test step.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 13:42:13 -07:00
Eileen Noonan
1620ba3508 docs: make table.update() nodejs guide consistent with API documentation (#2334)
The docs in the Guide here do not match the [API reference]
(https://lancedb.github.io/lancedb/js/classes/Table/#updateopts) for the
nodejs client.

I am writing an Elixir wrapper over the typescript library (Rust
forthcoming!) and confirmed in testing that the API reference is correct
vs the Guide.

Following the Guide docs, the error I got was:

"lance error: Invalid user input: Schema error: No field named bar.
Valid fields are foo. For a query of:

await table.update({foo: "buzz"}, { where: "foo = 'bar'"});
Over a table with a schema of just {foo: Utf8}.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Documentation**
- Reformatted a code snippet in the guide to enhance readability by
splitting it into multiple lines for improved clarity.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 08:38:16 -07:00
Ryan Green
3ae90dde80 feat: add new table API to wait for async indexing (#2338)
* Add new wait_for_index() table operation that polls until indices are
created/fully indexed
* Add an optional wait timeout parameter to all create_index operations
* Python and NodeJS interfaces

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Added optional waiting for index creation completion with configurable
timeout.
- Introduced methods to poll and wait for indices to be fully built
across sync and async tables.
  - Extended index creation APIs to accept a wait timeout parameter.
- **Bug Fixes**
- Added a new timeout error variant for improved error reporting on
index operations.
- **Tests**
- Added tests covering successful index readiness waiting, timeout
scenarios, and missing index cases.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 08:41:21 -02:30
Magnus
4f07fea6df feat: add ColPali embedding support with MultiVector type (#2170)
This PR adds ColPali support with ColPaliEmbeddings class (tagged
"colpali") using ColQwen2.5 for multi-vector text/image embeddings. Also
added MultiVector Pydantic type to handle the vector lists.

I've added some integration test for the embedding model and some unit
test for the new Pydantic type. Could be a template for other ColPali
variants as well. or until transformers🤗 starts supporting it.


Still `TODO`:

- [ ] Documentation
- [ ] Add an example

_Could also allow Image as query, but didn't work well when testing it._

[ColPali-Engine](https://github.com/illuin-tech/colpali) version:
0.3.9.dev17+g3faee24

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced support for ColPali-based multimodal multi-vector
embeddings for both text and images.
- Added a new embedding class for generating multi-vector embeddings,
configurable for various model and processing options.
- Added a new Pydantic type for multi-vector embeddings, supporting
validation and schema generation for lists of fixed-dimension vectors.

- **Bug Fixes**
- Ensured proper asynchronous index creation in query tests for improved
reliability.

- **Tests**
- Added integration tests for ColPali embeddings, including
text-to-image search and validation of multi-vector fields.
- Added comprehensive tests for the new multi-vector Pydantic type,
covering schema, validation, and default value behavior.

- **Chores**
  - Updated optional dependencies to include the ColPali engine.
  - Added utility to check for availability of flash attention support.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 11:47:37 +08:00
Lance Release
3d7d82cf86 Updating package-lock.json 2025-04-17 23:13:37 +00:00
Lance Release
edc4e40a7b Updating package-lock.json 2025-04-17 22:16:36 +00:00
Lance Release
ca3806a02f Updating package-lock.json 2025-04-17 22:16:20 +00:00
Lance Release
35cff12e31 Bump version: 0.19.0-beta.7 → 0.19.0-beta.8 2025-04-17 22:16:02 +00:00
56 changed files with 1703 additions and 387 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.19.0-beta.7"
current_version = "0.19.0-beta.10"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -18,17 +18,24 @@ concurrency:
group: "pages"
cancel-in-progress: true
env:
# This reduces the disk space needed for the build
RUSTFLAGS: "-C debuginfo=0"
# according to: https://matklad.github.io/2021/09/04/fast-rust-builds.html
# CI builds are faster with incremental disabled.
CARGO_INCREMENTAL: "0"
jobs:
# Single deploy job since we're just deploying
build:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: buildjet-8vcpu-ubuntu-2204
runs-on: ubuntu-24.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install dependecies needed for ubuntu
- name: Install dependencies needed for ubuntu
run: |
sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default
@@ -38,6 +45,7 @@ jobs:
python-version: "3.10"
cache: "pip"
cache-dependency-path: "docs/requirements.txt"
- uses: Swatinem/rust-cache@v2
- name: Build Python
working-directory: python
run: |
@@ -49,7 +57,6 @@ jobs:
node-version: 20
cache: 'npm'
cache-dependency-path: node/package-lock.json
- uses: Swatinem/rust-cache@v2
- name: Install node dependencies
working-directory: node
run: |

View File

@@ -136,9 +136,9 @@ jobs:
- uses: ./.github/workflows/run_tests
with:
integration: true
- name: Test without pylance
- name: Test without pylance or pandas
run: |
pip uninstall -y pylance
pip uninstall -y pylance pandas
pytest -vv python/tests/test_table.py
# Make sure wheels are not included in the Rust cache
- name: Delete wheels

119
Cargo.lock generated
View File

@@ -128,9 +128,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.97"
version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
[[package]]
name = "arbitrary"
@@ -390,9 +390,9 @@ dependencies = [
[[package]]
name = "async-compression"
version = "0.4.22"
version = "0.4.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59a194f9d963d8099596278594b3107448656ba73831c9d8c783e613ce86da64"
checksum = "b37fc50485c4f3f736a4fb14199f6d5f5ba008d7f28fe710306c92780f004c07"
dependencies = [
"flate2",
"futures-core",
@@ -564,9 +564,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.28.0"
version = "0.28.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9f7720b74ed28ca77f90769a71fd8c637a0137f6fae4ae947e1050229cff57f"
checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1"
dependencies = [
"bindgen",
"cc",
@@ -882,7 +882,7 @@ dependencies = [
"aws-smithy-async",
"aws-smithy-runtime-api",
"aws-smithy-types",
"h2 0.4.8",
"h2 0.4.9",
"http 0.2.12",
"http 1.3.1",
"http-body 0.4.6",
@@ -1185,9 +1185,9 @@ dependencies = [
[[package]]
name = "blake3"
version = "1.8.1"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3"
checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0"
dependencies = [
"arrayref",
"arrayvec",
@@ -2377,7 +2377,16 @@ version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
dependencies = [
"dirs-sys",
"dirs-sys 0.4.1",
]
[[package]]
name = "dirs"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e"
dependencies = [
"dirs-sys 0.5.0",
]
[[package]]
@@ -2388,10 +2397,22 @@ checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
dependencies = [
"libc",
"option-ext",
"redox_users",
"redox_users 0.4.6",
"windows-sys 0.48.0",
]
[[package]]
name = "dirs-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab"
dependencies = [
"libc",
"option-ext",
"redox_users 0.5.0",
"windows-sys 0.59.0",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
@@ -2558,9 +2579,9 @@ dependencies = [
[[package]]
name = "ethnum"
version = "1.5.0"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c"
checksum = "0939f82868b77ef93ce3c3c3daf2b3c526b456741da5a1a4559e590965b6026b"
[[package]]
name = "event-listener"
@@ -3049,9 +3070,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.4.8"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2"
checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633"
dependencies = [
"atomic-waker",
"bytes",
@@ -3138,7 +3159,7 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc03dcb0b0a83ae3f3363ec811014ae669f083e4e499c66602f447c4828737a1"
dependencies = [
"dirs",
"dirs 5.0.1",
"futures",
"http 1.3.1",
"indicatif",
@@ -3286,7 +3307,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.8",
"h2 0.4.9",
"http 1.3.1",
"http-body 1.0.1",
"httparse",
@@ -3645,9 +3666,9 @@ checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8"
[[package]]
name = "jiff"
version = "0.2.6"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f33145a5cbea837164362c7bd596106eb7c5198f97d1ba6f6ebb3223952e488"
checksum = "5a064218214dc6a10fbae5ec5fa888d80c45d611aba169222fc272072bf7aef6"
dependencies = [
"jiff-static",
"log",
@@ -3658,9 +3679,9 @@ dependencies = [
[[package]]
name = "jiff-static"
version = "0.2.6"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43ce13c40ec6956157a3635d97a1ee2df323b263f09ea14165131289cb0f5c19"
checksum = "199b7932d97e325aff3a7030e141eafe7f2c6268e1d1b24859b753a627f45254"
dependencies = [
"proc-macro2",
"quote",
@@ -3965,7 +3986,7 @@ dependencies = [
"datafusion-physical-expr",
"datafusion-sql",
"deepsize",
"dirs",
"dirs 5.0.1",
"fst",
"futures",
"half",
@@ -4115,7 +4136,7 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
dependencies = [
"arrow",
"arrow-array",
@@ -4202,7 +4223,7 @@ dependencies = [
[[package]]
name = "lancedb-node"
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -4227,7 +4248,7 @@ dependencies = [
[[package]]
name = "lancedb-nodejs"
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -4245,7 +4266,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.22.0-beta.7"
version = "0.22.0-beta.10"
dependencies = [
"arrow",
"env_logger",
@@ -4342,9 +4363,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.171"
version = "0.2.172"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6"
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
[[package]]
name = "libloading"
@@ -4368,9 +4389,9 @@ dependencies = [
[[package]]
name = "libm"
version = "0.2.11"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
checksum = "c9627da5196e5d8ed0b0495e61e518847578da83483c37288316d9b2e03a7f72"
[[package]]
name = "libredox"
@@ -5637,9 +5658,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.94"
version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84"
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
dependencies = [
"unicode-ident",
]
@@ -5837,13 +5858,13 @@ dependencies = [
[[package]]
name = "quinn-proto"
version = "0.11.10"
version = "0.11.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc"
checksum = "bcbafbbdbb0f638fe3f35f3c56739f77a8a1d070cb25603226c83339b391472b"
dependencies = [
"bytes",
"getrandom 0.3.2",
"rand 0.9.0",
"rand 0.9.1",
"ring",
"rustc-hash 2.1.1",
"rustls 0.23.26",
@@ -5903,13 +5924,12 @@ dependencies = [
[[package]]
name = "rand"
version = "0.9.0"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
"zerocopy 0.8.24",
]
[[package]]
@@ -6084,6 +6104,17 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "redox_users"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b"
dependencies = [
"getrandom 0.2.15",
"libredox",
"thiserror 2.0.12",
]
[[package]]
name = "regex"
version = "1.11.1"
@@ -6152,7 +6183,7 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2 0.4.8",
"h2 0.4.9",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
@@ -6701,11 +6732,11 @@ dependencies = [
[[package]]
name = "shellexpand"
version = "3.1.0"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da03fa3b94cc19e3ebfc88c4229c49d8f08cdbd1228870a45f0ffdf84988e14b"
checksum = "8b1fdf65dd6331831494dd616b30351c38e96e45921a27745cf98490458b90bb"
dependencies = [
"dirs",
"dirs 6.0.0",
]
[[package]]
@@ -6716,9 +6747,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
version = "1.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410"
dependencies = [
"libc",
]

View File

@@ -765,7 +765,10 @@ This can be used to update zero to all rows depending on how many rows match the
];
const tbl = await db.createTable("my_table", data)
await tbl.update({vector: [10, 10]}, { where: "x = 2"})
await tbl.update({
values: { vector: [10, 10] },
where: "x = 2"
});
```
=== "vectordb (deprecated)"
@@ -784,7 +787,10 @@ This can be used to update zero to all rows depending on how many rows match the
];
const tbl = await db.createTable("my_table", data)
await tbl.update({ where: "x = 2", values: {vector: [10, 10]} })
await tbl.update({
where: "x = 2",
values: { vector: [10, 10] }
});
```
#### Updating using a sql query

View File

@@ -753,3 +753,26 @@ Retrieve the version of the table
#### Returns
`Promise`&lt;`number`&gt;
***
### waitForIndex()
```ts
abstract waitForIndex(indexNames, timeoutSeconds): Promise<void>
```
Waits for asynchronous indexing to complete on the table.
#### Parameters
* **indexNames**: `string`[]
The name of the indices to wait for
* **timeoutSeconds**: `number`
The number of seconds to wait before timing out
This will raise an error if the indices are not created and fully indexed within the timeout.
#### Returns
`Promise`&lt;`void`&gt;

View File

@@ -39,3 +39,11 @@ and the same name, then an error will be returned. This is true even if
that index is out of date.
The default is true
***
### waitTimeoutSeconds?
```ts
optional waitTimeoutSeconds: number;
```

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.7</version>
<version>0.19.0-beta.10</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.7</version>
<version>0.19.0-beta.10</version>
<packaging>pom</packaging>
<name>LanceDB Parent</name>

44
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"cpu": [
"x64",
"arm64"
@@ -52,11 +52,11 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.7",
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.7"
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.10",
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.10",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.10",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.10",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.10"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
@@ -327,9 +327,9 @@
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.7.tgz",
"integrity": "sha512-HpbVKw4Vs+mPv7uPwaK7ilJlGrGdjOrNlC2mSkMCj0OlEwGRVcEcrSyijI7LXQH7ybEgNnDhSds5TuzBV26SGg==",
"version": "0.19.0-beta.10",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.10.tgz",
"integrity": "sha512-4PvsrE+hJ+AqFY33yHIEVET2ayv0pzpiWqCnMQ5OdpakQZvfykmp9ykc5KI80VuWAlniJDYuW+fju3z8/wiUHQ==",
"cpu": [
"arm64"
],
@@ -340,9 +340,9 @@
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.7.tgz",
"integrity": "sha512-x3X7nqIYVZtxaa0uZUk/M99vKvDinZ5G0+8k2NqZ696YXGWKGyRxR6k8ZzKYCoCTSuYXnBftgKoIlwJGtNt8Bw==",
"version": "0.19.0-beta.10",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.10.tgz",
"integrity": "sha512-YHXHkD4mmIry+KMoTX7Qts5Ea9fG0DklywIiF8TS7h/9XbXLG74lf+GUy2Eh/s1wKLd4LtRh2SbHpOtZoOH4lA==",
"cpu": [
"x64"
],
@@ -353,9 +353,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.7.tgz",
"integrity": "sha512-Vwj0HI3+b4NgXKf+5+W/GfLBCGoQMBGM47vA/ts1dpe/PxraOQYPDv67I5kbXkCQKwhal7b0iZx/PbMu0JZPyw==",
"version": "0.19.0-beta.10",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.10.tgz",
"integrity": "sha512-c019rw30N25WIXnkhAwJ4QkpUcUJqbkGay3RiR3vTmGQ5YOZWw5V5g/v2y7APcv+ZlZfJ4YgDjFH8wqtiECNJQ==",
"cpu": [
"arm64"
],
@@ -366,9 +366,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.7.tgz",
"integrity": "sha512-Dx2B6UWQei9D7Rt+MgHWqPTYtEK2w3EgsNb5ENEWUTZxH7lD/CV7Sw0JMK5LDG209fFcpXFerveF6J8ZC8uGBQ==",
"version": "0.19.0-beta.10",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.10.tgz",
"integrity": "sha512-xnbC6rqpuJDv2q6xNBKrrocNOTcM4z6+8Zi7wP+Sb+WXvLzkR7hm7ZS0gyeExRknEEd91imhL/ZuAxEq1892YA==",
"cpu": [
"x64"
],
@@ -379,9 +379,9 @@
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.7.tgz",
"integrity": "sha512-F5LZGa+gkUH1TgsWZWLLAMejwXFIWdash7+85ip4k2M0ThyqLF/dtlldOvteUEd5+flxihGjHg6TUtnSY8XBFA==",
"version": "0.19.0-beta.10",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.10.tgz",
"integrity": "sha512-bUDKvY4tmMEeBpzPfRu2lVo+07nRGzUoUm60WzfvRpa/Y6rwjcCCRuuTOvfTcrnbGYN/kw5yoUu8ZDsZ7mT77Q==",
"cpu": [
"x64"
],

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"description": " Serverless, low-latency vector database for AI applications",
"private": false,
"main": "dist/index.js",
@@ -89,10 +89,10 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.7",
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.7",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.7"
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.10",
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.10",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.10",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.10",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.10"
}
}

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -507,6 +507,15 @@ describe("When creating an index", () => {
expect(indices2.length).toBe(0);
});
it("should wait for index readiness", async () => {
// Create an index and then wait for it to be ready
await tbl.createIndex("vec");
const indices = await tbl.listIndices();
expect(indices.length).toBeGreaterThan(0);
const idxName = indices[0].name;
await expect(tbl.waitForIndex([idxName], 5)).resolves.toBeUndefined();
});
it("should search with distance range", async () => {
await tbl.createIndex("vec");
@@ -824,6 +833,7 @@ describe("When creating an index", () => {
// Only build index over v1
await tbl.createIndex("vec", {
config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }),
waitTimeoutSeconds: 30,
});
const rst = await tbl

View File

@@ -681,4 +681,6 @@ export interface IndexOptions {
* The default is true
*/
replace?: boolean;
waitTimeoutSeconds?: number;
}

View File

@@ -246,6 +246,19 @@ export abstract class Table {
*/
abstract prewarmIndex(name: string): Promise<void>;
/**
* Waits for asynchronous indexing to complete on the table.
*
* @param indexNames The name of the indices to wait for
* @param timeoutSeconds The number of seconds to wait before timing out
*
* This will raise an error if the indices are not created and fully indexed within the timeout.
*/
abstract waitForIndex(
indexNames: string[],
timeoutSeconds: number,
): Promise<void>;
/**
* Create a {@link Query} Builder.
*
@@ -569,7 +582,12 @@ export class LocalTable extends Table {
// Bit of a hack to get around the fact that TS has no package-scope.
// biome-ignore lint/suspicious/noExplicitAny: skip
const nativeIndex = (options?.config as any)?.inner;
await this.inner.createIndex(nativeIndex, column, options?.replace);
await this.inner.createIndex(
nativeIndex,
column,
options?.replace,
options?.waitTimeoutSeconds,
);
}
async dropIndex(name: string): Promise<void> {
@@ -580,6 +598,13 @@ export class LocalTable extends Table {
await this.inner.prewarmIndex(name);
}
async waitForIndex(
indexNames: string[],
timeoutSeconds: number,
): Promise<void> {
await this.inner.waitForIndex(indexNames, timeoutSeconds);
}
query(): Query {
return new Query(this.inner);
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.19.0-beta.7",
"version": "0.19.0-beta.10",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -111,6 +111,7 @@ impl Table {
index: Option<&Index>,
column: String,
replace: Option<bool>,
wait_timeout_s: Option<i64>,
) -> napi::Result<()> {
let lancedb_index = if let Some(index) = index {
index.consume()?
@@ -121,6 +122,10 @@ impl Table {
if let Some(replace) = replace {
builder = builder.replace(replace);
}
if let Some(timeout) = wait_timeout_s {
builder =
builder.wait_timeout(std::time::Duration::from_secs(timeout.try_into().unwrap()));
}
builder.execute().await.default_error()
}
@@ -140,6 +145,18 @@ impl Table {
.default_error()
}
#[napi(catch_unwind)]
pub async fn wait_for_index(&self, index_names: Vec<String>, timeout_s: i64) -> Result<()> {
let timeout = std::time::Duration::from_secs(timeout_s.try_into().unwrap());
let index_names: Vec<&str> = index_names.iter().map(|s| s.as_str()).collect();
let slice: &[&str] = &index_names;
self.inner_ref()?
.wait_for_index(slice, timeout)
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn update(
&self,

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.22.0-beta.8"
current_version = "0.22.0-beta.11"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.22.0-beta.8"
version = "0.22.0-beta.11"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -77,6 +77,7 @@ embeddings = [
"pillow",
"open-clip-torch",
"cohere",
"colpali-engine>=0.3.10",
"huggingface_hub",
"InstructorEmbedding",
"google.generativeai",

View File

@@ -9,7 +9,7 @@ import numpy as np
import pyarrow as pa
import pyarrow.dataset
from .dependencies import pandas as pd
from .dependencies import _check_for_pandas, pandas as pd
DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
@@ -63,7 +63,7 @@ def data_to_reader(
data: DATA, schema: Optional[pa.Schema] = None
) -> pa.RecordBatchReader:
"""Convert various types of input into a RecordBatchReader"""
if pd is not None and isinstance(data, pd.DataFrame):
if _check_for_pandas(data) and isinstance(data, pd.DataFrame):
return pa.Table.from_pandas(data, schema=schema).to_reader()
elif isinstance(data, pa.Table):
return data.to_reader()

View File

@@ -19,3 +19,4 @@ from .imagebind import ImageBindEmbeddings
from .jinaai import JinaEmbeddings
from .watsonx import WatsonxEmbeddings
from .voyageai import VoyageAIEmbeddingFunction
from .colpali import ColPaliEmbeddings

View File

@@ -0,0 +1,255 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from functools import lru_cache
from typing import List, Union, Optional, Any
import numpy as np
import io
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import TEXT, IMAGES, is_flash_attn_2_available
@register("colpali")
class ColPaliEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the ColPali engine for
multimodal multi-vector embeddings.
This embedding function supports ColQwen2.5 models, producing multivector outputs
for both text and image inputs. The output embeddings are lists of vectors, each
vector being 128-dimensional by default, represented as List[List[float]].
Parameters
----------
model_name : str
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
device : str
The device for inference (default "cuda:0").
dtype : str
Data type for model weights (default "bfloat16").
use_token_pooling : bool
Whether to use token pooling to reduce embedding size (default True).
pool_factor : int
Factor to reduce sequence length if token pooling is enabled (default 2).
quantization_config : Optional[BitsAndBytesConfig]
Quantization configuration for the model. (default None, bitsandbytes needed)
batch_size : int
Batch size for processing inputs (default 2).
"""
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
device: str = "auto"
dtype: str = "bfloat16"
use_token_pooling: bool = True
pool_factor: int = 2
quantization_config: Optional[Any] = None
batch_size: int = 2
_model = None
_processor = None
_token_pooler = None
_vector_dim = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
(
self._model,
self._processor,
self._token_pooler,
) = self._load_model(
self.model_name,
self.dtype,
self.device,
self.use_token_pooling,
self.quantization_config,
)
@staticmethod
@lru_cache(maxsize=1)
def _load_model(
model_name: str,
dtype: str,
device: str,
use_token_pooling: bool,
quantization_config: Optional[Any],
):
"""
Initialize and cache the ColPali model, processor, and token pooler.
"""
torch = attempt_import_or_raise("torch", "torch")
transformers = attempt_import_or_raise("transformers", "transformers")
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
if quantization_config is not None:
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
raise ValueError("quantization_config must be a BitsAndBytesConfig")
if dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif dtype == "float16":
torch_dtype = torch.float16
elif dtype == "float64":
torch_dtype = torch.float64
else:
torch_dtype = torch.float32
model = colpali_engine.models.ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map=device,
quantization_config=quantization_config
if quantization_config is not None
else None,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
).eval()
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
model_name
)
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
return model, processor, token_pooler
def ndims(self):
"""
Return the dimension of a vector in the multivector output (e.g., 128).
"""
torch = attempt_import_or_raise("torch", "torch")
if self._vector_dim is None:
dummy_query = "test"
batch_queries = self._processor.process_queries([dummy_query]).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
if self.use_token_pooling and self._token_pooler is not None:
query_embeddings = self._token_pooler.pool_embeddings(
query_embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
self._vector_dim = query_embeddings[0].shape[-1]
return self._vector_dim
def _process_embeddings(self, embeddings):
"""
Format model embeddings into List[List[float]].
Use token pooling if enabled.
"""
torch = attempt_import_or_raise("torch", "torch")
if self.use_token_pooling and self._token_pooler is not None:
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
if isinstance(embeddings, torch.Tensor):
tensors = embeddings.detach().cpu()
if tensors.dtype == torch.bfloat16:
tensors = tensors.to(torch.float32)
return (
tensors.numpy()
.astype(np.float64 if self.dtype == "float64" else np.float32)
.tolist()
)
return []
def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
"""
Generate embeddings for text input.
"""
torch = attempt_import_or_raise("torch", "torch")
text = self.sanitize_input(text)
all_embeddings = []
for i in range(0, len(text), self.batch_size):
batch_text = text[i : i + self.batch_size]
batch_queries = self._processor.process_queries(batch_text).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
all_embeddings.extend(self._process_embeddings(query_embeddings))
return all_embeddings
def _prepare_images(self, images: IMAGES) -> List:
"""
Convert image inputs to PIL Images.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
try:
for image in images:
if isinstance(image, str):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
else:
with PIL.Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL.Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)
except Exception as e:
raise ValueError(f"Failed to process image: {e}")
return pil_images
def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
"""
Generate embeddings for a batch of images.
"""
torch = attempt_import_or_raise("torch", "torch")
pil_images = self._prepare_images(images)
all_embeddings = []
for i in range(0, len(pil_images), self.batch_size):
batch_images = pil_images[i : i + self.batch_size]
batch_images = self._processor.process_images(batch_images).to(
self._model.device
)
with torch.no_grad():
image_embeddings = self._model(**batch_images)
all_embeddings.extend(self._process_embeddings(image_embeddings))
return all_embeddings
def compute_query_embeddings(
self, query: Union[str, IMAGES], *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a single user query (text only).
"""
if not isinstance(query, str):
raise ValueError(
"Query must be a string, image to image search is not supported"
)
return self.generate_text_embeddings([query])
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a batch of source images.
Parameters
----------
images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
Batch of images (paths, URLs, bytes, or PIL Images).
"""
images = self.sanitize_input(images)
return self.generate_image_embeddings(images)

View File

@@ -18,6 +18,7 @@ import numpy as np
import pyarrow as pa
from ..dependencies import pandas as pd
from ..util import attempt_import_or_raise
# ruff: noqa: PERF203
@@ -275,3 +276,12 @@ def url_retrieve(url: str):
def api_key_not_found_help(provider):
logging.error("Could not find API key for %s", provider)
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
def is_flash_attn_2_available():
try:
attempt_import_or_raise("flash_attn", "flash_attn")
return True
except ImportError:
return False

View File

@@ -152,6 +152,104 @@ def Vector(
return FixedSizeList
def MultiVector(
dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True
) -> Type:
"""Pydantic MultiVector Type for multi-vector embeddings.
This type represents a list of vectors, each with the same dimension.
Useful for models that produce multiple embeddings per input, like ColPali.
Parameters
----------
dim : int
The dimension of each vector in the multi-vector.
value_type : pyarrow.DataType, optional
The value type of the vectors, by default pa.float32()
nullable : bool, optional
Whether the multi-vector is nullable, by default it is True.
Examples
--------
>>> import pydantic
>>> from lancedb.pydantic import MultiVector
...
>>> class MyModel(pydantic.BaseModel):
... id: int
... text: str
... embeddings: MultiVector(128) # List of 128-dimensional vectors
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
... pa.field("text", pa.utf8(), False),
... pa.field("embeddings", pa.list_(pa.list_(pa.float32(), 128)))
... ])
"""
class MultiVectorList(list, FixedSizeListMixin):
def __repr__(self):
return f"MultiVector(dim={dim})"
@staticmethod
def nullable() -> bool:
return nullable
@staticmethod
def dim() -> int:
return dim
@staticmethod
def value_arrow_type() -> pa.DataType:
return value_type
@staticmethod
def is_multi_vector() -> bool:
return True
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.list_schema(
items_schema=core_schema.list_schema(
min_length=dim,
max_length=dim,
items_schema=core_schema.float_schema(),
),
),
)
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
@classmethod
def validate(cls, v):
if not isinstance(v, (list, range)):
raise TypeError("A list of vectors is needed")
for vec in v:
if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim:
raise TypeError(f"Each vector must be a list of {dim} numbers")
return cls(v)
if PYDANTIC_VERSION.major < 2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {
"type": "array",
"items": {"type": "number"},
"minItems": dim,
"maxItems": dim,
}
return MultiVectorList
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
"""Convert a field with native Python type to Arrow data type.
@@ -206,6 +304,9 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
fields = _pydantic_model_to_fields(tp)
return pa.struct(fields)
if issubclass(tp, FixedSizeListMixin):
if getattr(tp, "is_multi_vector", lambda: False)():
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
# For regular Vector
return pa.list_(tp.value_arrow_type(), tp.dim())
return _py_type_to_arrow_type(tp, field)

View File

@@ -28,6 +28,8 @@ import pyarrow.compute as pc
import pyarrow.fs as pa_fs
import pydantic
from lancedb.pydantic import PYDANTIC_VERSION
from . import __version__
from .arrow import AsyncRecordBatchReader
from .dependencies import pandas as pd
@@ -498,10 +500,14 @@ class Query(pydantic.BaseModel):
)
return query
class Config:
# This tells pydantic to allow custom types (needed for the `vector` query since
# pa.Array wouln't be allowed otherwise)
arbitrary_types_allowed = True
# This tells pydantic to allow custom types (needed for the `vector` query since
# pa.Array wouln't be allowed otherwise)
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
class Config:
arbitrary_types_allowed = True
else:
model_config = {"arbitrary_types_allowed": True}
class LanceQueryBuilder(ABC):
@@ -1586,6 +1592,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._refine_factor = None
self._distance_type = None
self._phrase_query = None
self._lower_bound = None
self._upper_bound = None
def _validate_query(self, query, vector=None, text=None):
if query is not None and (vector is not None or text is not None):
@@ -1665,6 +1673,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.ef(self._ef)
if self._bypass_vector_index:
self._vector_query.bypass_vector_index()
if self._lower_bound or self._upper_bound:
self._vector_query.distance_range(
lower_bound=self._lower_bound, upper_bound=self._upper_bound
)
if self._reranker is None:
self._reranker = RRFReranker()

View File

@@ -104,6 +104,7 @@ class RemoteTable(Table):
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*,
replace: bool = False,
wait_timeout: timedelta = None,
):
"""Creates a scalar index
Parameters
@@ -126,13 +127,18 @@ class RemoteTable(Table):
else:
raise ValueError(f"Unknown index type: {index_type}")
LOOP.run(self._table.create_index(column, config=config, replace=replace))
LOOP.run(
self._table.create_index(
column, config=config, replace=replace, wait_timeout=wait_timeout
)
)
def create_fts_index(
self,
column: str,
*,
replace: bool = False,
wait_timeout: timedelta = None,
with_position: bool = True,
# tokenizer configs:
base_tokenizer: str = "simple",
@@ -153,7 +159,11 @@ class RemoteTable(Table):
remove_stop_words=remove_stop_words,
ascii_folding=ascii_folding,
)
LOOP.run(self._table.create_index(column, config=config, replace=replace))
LOOP.run(
self._table.create_index(
column, config=config, replace=replace, wait_timeout=wait_timeout
)
)
def create_index(
self,
@@ -165,6 +175,7 @@ class RemoteTable(Table):
replace: Optional[bool] = None,
accelerator: Optional[str] = None,
index_type="vector",
wait_timeout: Optional[timedelta] = None,
):
"""Create an index on the table.
Currently, the only parameters that matter are
@@ -236,7 +247,11 @@ class RemoteTable(Table):
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
LOOP.run(self._table.create_index(vector_column_name, config=config))
LOOP.run(
self._table.create_index(
vector_column_name, config=config, wait_timeout=wait_timeout
)
)
def add(
self,
@@ -554,6 +569,11 @@ class RemoteTable(Table):
def drop_index(self, index_name: str):
return LOOP.run(self._table.drop_index(index_name))
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
):
return LOOP.run(self._table.wait_for_index(index_names, timeout))
def uses_v2_manifest_paths(self) -> bool:
raise NotImplementedError(
"uses_v2_manifest_paths() is not supported on the LanceDB Cloud"

View File

@@ -631,6 +631,7 @@ class Table(ABC):
index_cache_size: Optional[int] = None,
*,
index_type: VectorIndexType = "IVF_PQ",
wait_timeout: Optional[timedelta] = None,
num_bits: int = 8,
max_iterations: int = 50,
sample_rate: int = 256,
@@ -666,6 +667,8 @@ class Table(ABC):
num_bits: int
The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
Only 4 and 8 are supported.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
"""
raise NotImplementedError
@@ -689,6 +692,23 @@ class Table(ABC):
"""
raise NotImplementedError
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
"""
Wait for indexing to complete for the given index names.
This will poll the table until all the indices are fully indexed,
or raise a timeout exception if the timeout is reached.
Parameters
----------
index_names: str
The name of the indices to poll
timeout: timedelta
Timeout to wait for asynchronous indexing. The default is 5 minutes.
"""
raise NotImplementedError
@abstractmethod
def create_scalar_index(
self,
@@ -696,6 +716,7 @@ class Table(ABC):
*,
replace: bool = True,
index_type: ScalarIndexType = "BTREE",
wait_timeout: Optional[timedelta] = None,
):
"""Create a scalar index on a column.
@@ -708,7 +729,8 @@ class Table(ABC):
Replace the existing index if it exists.
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE"
The type of index to create.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
Examples
--------
@@ -767,6 +789,7 @@ class Table(ABC):
stem: bool = False,
remove_stop_words: bool = False,
ascii_folding: bool = False,
wait_timeout: Optional[timedelta] = None,
):
"""Create a full-text search index on the table.
@@ -822,6 +845,8 @@ class Table(ABC):
ascii_folding : bool, default False
Whether to fold ASCII characters. This converts accented characters to
their ASCII equivalent. For example, "café" would be converted to "cafe".
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
"""
raise NotImplementedError
@@ -1771,6 +1796,11 @@ class LanceTable(Table):
"""
return LOOP.run(self._table.prewarm_index(name))
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
return LOOP.run(self._table.wait_for_index(index_names, timeout))
def create_scalar_index(
self,
column: str,
@@ -2964,6 +2994,7 @@ class AsyncTable:
config: Optional[
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
wait_timeout: Optional[timedelta] = None,
):
"""Create an index to speed up queries
@@ -2988,6 +3019,8 @@ class AsyncTable:
For advanced configuration you can specify the type of index you would
like to create. You can also specify index-specific parameters when
creating an index object.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
"""
if config is not None:
if not isinstance(
@@ -2998,7 +3031,9 @@ class AsyncTable:
" Bitmap, LabelList, or FTS"
)
try:
await self._inner.create_index(column, index=config, replace=replace)
await self._inner.create_index(
column, index=config, replace=replace, wait_timeout=wait_timeout
)
except ValueError as e:
if "not support the requested language" in str(e):
supported_langs = ", ".join(lang_mapping.values())
@@ -3043,6 +3078,23 @@ class AsyncTable:
"""
await self._inner.prewarm_index(name)
async def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
"""
Wait for indexing to complete for the given index names.
This will poll the table until all the indices are fully indexed,
or raise a timeout exception if the timeout is reached.
Parameters
----------
index_names: str
The name of the indices to poll
timeout: timedelta
Timeout to wait for asynchronous indexing. The default is 5 minutes.
"""
await self._inner.wait_for_index(index_names, timeout)
async def add(
self,
data: DATA,

View File

@@ -11,7 +11,7 @@ import pandas as pd
import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel, Vector, MultiVector
import requests
# These are integration tests for embedding functions.
@@ -575,3 +575,67 @@ def test_voyageai_multimodal_embedding_text_function():
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
def test_colpali(tmp_path):
import requests
from lancedb.pydantic import LanceModel
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("colpali").create()
class MediaItems(LanceModel):
text: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
image_vectors: MultiVector(func.ndims()) = (
func.VectorField()
) # Multivector image embeddings
table = db.create_table("media", schema=MediaItems)
texts = [
"a cute cat playing with yarn",
"a puppy in a flower field",
"a red sports car on the highway",
"a vintage bicycle leaning against a wall",
"a plate of delicious pasta",
"fresh fruit salad in a bowl",
]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# Get images as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)
# Test text-to-image search
image_results = (
table.search("fluffy companion", vector_column_name="image_vectors")
.limit(1)
.to_pydantic(MediaItems)[0]
)
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
# Verify multivector dimensions
first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
assert len(first_row["image_vectors"][0]) == func.ndims(), (
"Vector dimension mismatch"
)

View File

@@ -4,13 +4,32 @@
import lancedb
from lancedb.query import LanceHybridQueryBuilder
from lancedb.rerankers.rrf import RRFReranker
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.index import FTS
from lancedb.table import AsyncTable
from lancedb.table import AsyncTable, Table
@pytest.fixture
def sync_table(tmpdir_factory) -> Table:
tmp_path = str(tmpdir_factory.mktemp("data"))
db = lancedb.connect(tmp_path)
data = pa.table(
{
"text": pa.array(["a", "b", "cat", "dog"]),
"vector": pa.array(
[[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
type=pa.list_(pa.float32(), list_size=2),
),
}
)
table = db.create_table("test", data)
table.create_fts_index("text", with_position=False, use_tantivy=False)
return table
@pytest_asyncio.fixture
@@ -102,6 +121,42 @@ async def test_async_hybrid_query_default_limit(table: AsyncTable):
assert texts.count("a") == 1
def test_hybrid_query_distance_range(sync_table: Table):
reranker = RRFReranker(return_score="all")
result = (
sync_table.search(query_type="hybrid")
.vector([0.0, 0.4])
.text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
print(result)
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio
async def test_hybrid_query_distance_range_async(table: AsyncTable):
reranker = RRFReranker(return_score="all")
result = await (
table.query()
.nearest_to([0.0, 0.4])
.nearest_to_text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio
async def test_explain_plan(table: AsyncTable):
plan = await (

View File

@@ -9,7 +9,13 @@ from typing import List, Optional, Tuple
import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from lancedb.pydantic import (
PYDANTIC_VERSION,
LanceModel,
Vector,
pydantic_to_schema,
MultiVector,
)
from pydantic import BaseModel
from pydantic import Field
@@ -354,3 +360,55 @@ def test_optional_nested_model():
),
]
)
def test_multi_vector():
class TestModel(pydantic.BaseModel):
vec: MultiVector(8)
schema = pydantic_to_schema(TestModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=[[1.0] * 7])
with pytest.raises(pydantic.ValidationError):
TestModel(vec=[[1.0] * 9])
TestModel(vec=[[1.0] * 8])
TestModel(vec=[[1.0] * 8, [2.0] * 8])
TestModel(vec=[])
def test_multi_vector_nullable():
class NullableModel(pydantic.BaseModel):
vec: MultiVector(16, nullable=False)
schema = pydantic_to_schema(NullableModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)]
)
class DefaultModel(pydantic.BaseModel):
vec: MultiVector(16)
schema = pydantic_to_schema(DefaultModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)]
)
def test_multi_vector_in_lance_model():
class TestModel(LanceModel):
id: int
vectors: MultiVector(16) = Field(default=[[0.0] * 16])
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["id", "vectors"]
t = TestModel(id=1)
assert t.vectors == [[0.0] * 16]

View File

@@ -257,7 +257,9 @@ async def test_distance_range_with_new_rows_async():
}
)
table = await conn.create_table("test", data)
table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))
await table.create_index(
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=2)
)
q = [0, 0]
rs = await table.query().nearest_to(q).to_arrow()

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import re
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import timedelta
@@ -235,6 +235,10 @@ def test_table_add_in_threadpool():
def test_table_create_indices():
def handler(request):
index_stats = dict(
index_type="IVF_PQ", num_indexed_rows=1000, num_unindexed_rows=0
)
if request.path == "/v1/table/test/create_index/":
request.send_response(200)
request.end_headers()
@@ -258,6 +262,47 @@ def test_table_create_indices():
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/list/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
indexes=[
{
"index_name": "id_idx",
"columns": ["id"],
},
{
"index_name": "text_idx",
"columns": ["text"],
},
{
"index_name": "vector_idx",
"columns": ["vector"],
},
]
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/id_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/text_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/vector_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif "/drop/" in request.path:
request.send_response(200)
request.end_headers()
@@ -269,14 +314,81 @@ def test_table_create_indices():
# Parameters are well-tested through local and async tests.
# This is a smoke-test.
table = db.create_table("test", [{"id": 1}])
table.create_scalar_index("id")
table.create_fts_index("text")
table.create_scalar_index("vector")
table.create_scalar_index("id", wait_timeout=timedelta(seconds=2))
table.create_fts_index("text", wait_timeout=timedelta(seconds=2))
table.create_index(
vector_column_name="vector", wait_timeout=timedelta(seconds=10)
)
table.wait_for_index(["id_idx"], timedelta(seconds=2))
table.wait_for_index(["text_idx", "vector_idx"], timedelta(seconds=2))
table.drop_index("vector_idx")
table.drop_index("id_idx")
table.drop_index("text_idx")
def test_table_wait_for_index_timeout():
def handler(request):
index_stats = dict(
index_type="BTREE", num_indexed_rows=1000, num_unindexed_rows=1
)
if request.path == "/v1/table/test/create/?mode=create":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
version=1,
schema=dict(
fields=[
dict(name="id", type={"type": "int64"}, nullable=False),
]
),
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/list/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
indexes=[
{
"index_name": "id_idx",
"columns": ["id"],
},
]
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/id_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
print(f"{index_stats=}")
request.wfile.write(payload.encode())
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
with pytest.raises(
RuntimeError,
match=re.escape(
'Timeout error: timed out waiting for indices: ["id_idx"] after 1s'
),
):
table.wait_for_index(["id_idx"], timedelta(seconds=1))
@contextlib.contextmanager
def query_test_table(query_handler, *, server_version=Version("0.1.0")):
def handler(request):

View File

@@ -9,9 +9,9 @@ from typing import List
from unittest.mock import patch
import lancedb
from lancedb.dependencies import _PANDAS_AVAILABLE
from lancedb.index import HnswPq, HnswSq, IvfPq
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset
@@ -138,13 +138,16 @@ def test_create_table(mem_db: DBConnection):
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
]
df = pd.DataFrame(rows)
pa_table = pa.Table.from_pandas(df, schema=schema)
pa_table = pa.Table.from_pylist(rows, schema=schema)
data = [
("Rows", rows),
("pd_DataFrame", df),
("pa_Table", pa_table),
]
if _PANDAS_AVAILABLE:
import pandas as pd
df = pd.DataFrame(rows)
data.append(("pd_DataFrame", df))
for name, d in data:
tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow()
@@ -296,7 +299,7 @@ def test_add_subschema(mem_db: DBConnection):
data = {"price": 10.0, "item": "foo"}
table.add([data])
data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
data = pa.Table.from_pydict({"price": [2.0], "vector": [[3.1, 4.1]]})
table.add(data)
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
table.add([data])
@@ -405,6 +408,7 @@ def test_add_nullability(mem_db: DBConnection):
def test_add_pydantic_model(mem_db: DBConnection):
pytest.importorskip("pandas")
# https://github.com/lancedb/lancedb/issues/562
class Metadata(BaseModel):
@@ -473,10 +477,10 @@ def test_polars(mem_db: DBConnection):
table = mem_db.create_table("test", data=pl.DataFrame(data))
assert len(table) == 2
result = table.to_pandas()
assert np.allclose(result["vector"].tolist(), data["vector"])
assert result["item"].tolist() == data["item"]
assert np.allclose(result["price"].tolist(), data["price"])
result = table.to_arrow()
assert np.allclose(result["vector"].to_pylist(), data["vector"])
assert result["item"].to_pylist() == data["item"]
assert np.allclose(result["price"].to_pylist(), data["price"])
schema = pa.schema(
[
@@ -688,7 +692,7 @@ def test_delete(mem_db: DBConnection):
assert len(table.list_versions()) == 2
assert table.version == 2
assert len(table) == 1
assert table.to_pandas()["id"].tolist() == [1]
assert table.to_arrow()["id"].to_pylist() == [1]
def test_update(mem_db: DBConnection):
@@ -852,6 +856,7 @@ def test_merge_insert(mem_db: DBConnection):
ids=["pa.Table", "pd.DataFrame", "rows"],
)
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
pytest.importorskip("pandas")
initial_data = pa.table(
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
)
@@ -948,7 +953,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
func = MockTextEmbeddingFunction.create()
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
df = pa.table({"text": texts, "vector": func.compute_source_embeddings(texts)})
conf = EmbeddingFunctionConfig(
source_column="text", vector_column="vector", function=func
@@ -973,7 +978,7 @@ def test_create_f16_table(mem_db: DBConnection):
text: str
vector: Vector(32, value_type=pa.float16())
df = pd.DataFrame(
df = pa.table(
{
"text": [f"s-{i}" for i in range(512)],
"vector": [np.random.randn(32).astype(np.float16) for _ in range(512)],
@@ -986,7 +991,7 @@ def test_create_f16_table(mem_db: DBConnection):
table.add(df)
table.create_index(num_partitions=2, num_sub_vectors=2)
query = df.vector.iloc[2]
query = df["vector"][2].as_py()
expected = table.search(query).limit(2).to_arrow()
assert "s-2" in expected["text"].to_pylist()
@@ -1002,7 +1007,7 @@ def test_add_with_embedding_function(mem_db: DBConnection):
table = mem_db.create_table("my_table", schema=MyTable)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
df = pa.table({"text": texts})
table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"]
@@ -1033,14 +1038,14 @@ def test_multiple_vector_columns(mem_db: DBConnection):
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
df = pa.Table.from_pylist(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
result1 = table.search(q, vector_column_name="vector1").limit(1).to_arrow()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_arrow()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
assert result1["text"][0] != result2["text"][0]
def test_create_scalar_index(mem_db: DBConnection):
@@ -1078,22 +1083,22 @@ def test_empty_query(mem_db: DBConnection):
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
val = df.id.iloc[0]
df = table.search().select(["id"]).where("text='bar'").limit(1).to_arrow()
val = df["id"][0].as_py()
assert val == 1
table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)])
df = table.search().select(["id"]).to_pandas()
assert len(df) == 100
df = table.search().select(["id"]).to_arrow()
assert df.num_rows == 100
# None is the same as default
df = table.search().select(["id"]).limit(None).to_pandas()
assert len(df) == 100
df = table.search().select(["id"]).limit(None).to_arrow()
assert df.num_rows == 100
# invalid limist is the same as None, wihch is the same as default
df = table.search().select(["id"]).limit(-1).to_pandas()
assert len(df) == 100
df = table.search().select(["id"]).limit(-1).to_arrow()
assert df.num_rows == 100
# valid limit should work
df = table.search().select(["id"]).limit(42).to_pandas()
assert len(df) == 42
df = table.search().select(["id"]).limit(42).to_arrow()
assert df.num_rows == 42
def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
@@ -1112,14 +1117,14 @@ def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
{"vector_col": v1, "text": "foo"},
{"vector_col": v2, "text": "bar"},
]
df = pd.DataFrame(data)
df = pa.Table.from_pylist(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
result2 = table.search(q).limit(1).to_pandas()
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_arrow()
result2 = table.search(q).limit(1).to_arrow()
assert result1["text"].iloc[0] == result2["text"].iloc[0]
assert result1["text"][0].as_py() == result2["text"][0].as_py()
def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
@@ -1139,12 +1144,12 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
df = pa.Table.from_pylist(data)
table.add(df)
q = np.random.randn(10)
with pytest.raises(ValueError):
table.search(q).limit(1).to_pandas()
table.search(q).limit(1).to_arrow()
def test_compact_cleanup(tmp_db: DBConnection):

View File

@@ -652,6 +652,11 @@ impl HybridQuery {
self.inner_vec.bypass_vector_index();
}
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
self.inner_vec.distance_range(lower_bound, upper_bound);
}
pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> {
Ok(VectorQuery {
inner: self.inner_vec.inner.clone(),

View File

@@ -177,15 +177,19 @@ impl Table {
})
}
#[pyo3(signature = (column, index=None, replace=None))]
#[pyo3(signature = (column, index=None, replace=None, wait_timeout=None))]
pub fn create_index<'a>(
self_: PyRef<'a, Self>,
column: String,
index: Option<Bound<'_, PyAny>>,
replace: Option<bool>,
wait_timeout: Option<Bound<'_, PyAny>>,
) -> PyResult<Bound<'a, PyAny>> {
let index = extract_index_params(&index)?;
let mut op = self_.inner_ref()?.create_index(&[column], index);
let timeout = wait_timeout.map(|t| t.extract::<std::time::Duration>().unwrap());
let mut op = self_
.inner_ref()?
.create_index_with_timeout(&[column], index, timeout);
if let Some(replace) = replace {
op = op.replace(replace);
}
@@ -204,6 +208,26 @@ impl Table {
})
}
pub fn wait_for_index<'a>(
self_: PyRef<'a, Self>,
index_names: Vec<String>,
timeout: Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.inner_ref()?.clone();
let timeout = timeout.extract::<std::time::Duration>()?;
future_into_py(self_.py(), async move {
let index_refs = index_names
.iter()
.map(String::as_str)
.collect::<Vec<&str>>();
inner
.wait_for_index(&index_refs, timeout)
.await
.infer_error()?;
Ok(())
})
}
pub fn prewarm_index(self_: PyRef<'_, Self>, index_name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.19.0-beta.7"
version = "0.19.0-beta.10"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -35,6 +35,8 @@ pub enum Error {
Schema { message: String },
#[snafu(display("Runtime error: {message}"))]
Runtime { message: String },
#[snafu(display("Timeout error: {message}"))]
Timeout { message: String },
// 3rd party / external errors
#[snafu(display("object_store error: {source}"))]

View File

@@ -1,11 +1,11 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use scalar::FtsIndexBuilder;
use serde::Deserialize;
use serde_with::skip_serializing_none;
use std::sync::Arc;
use std::time::Duration;
use vector::IvfFlatIndexBuilder;
use crate::{table::BaseTable, DistanceType, Error, Result};
@@ -17,6 +17,7 @@ use self::{
pub mod scalar;
pub mod vector;
pub mod waiter;
/// Supported index types.
#[derive(Debug, Clone)]
@@ -69,6 +70,7 @@ pub struct IndexBuilder {
pub(crate) index: Index,
pub(crate) columns: Vec<String>,
pub(crate) replace: bool,
pub(crate) wait_timeout: Option<Duration>,
}
impl IndexBuilder {
@@ -78,6 +80,7 @@ impl IndexBuilder {
index,
columns,
replace: true,
wait_timeout: None,
}
}
@@ -91,6 +94,15 @@ impl IndexBuilder {
self
}
/// Duration of time to wait for asynchronous indexing to complete. If not set,
/// `create_index()` will not wait.
///
/// This is not supported for `NativeTable` since indexing is synchronous.
pub fn wait_timeout(mut self, d: Duration) -> Self {
self.wait_timeout = Some(d);
self
}
pub async fn execute(self) -> Result<()> {
self.parent.clone().create_index(self).await
}

View File

@@ -0,0 +1,89 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::error::Result;
use crate::table::BaseTable;
use crate::Error;
use log::debug;
use std::time::{Duration, Instant};
use tokio::time::sleep;
const DEFAULT_SLEEP_MS: u64 = 1000;
const MAX_WAIT: Duration = Duration::from_secs(2 * 60 * 60);
/// Poll the table using list_indices() and index_stats() until all of the indices have 0 un-indexed rows.
/// Will return Error::Timeout if the columns are not fully indexed within the timeout.
pub async fn wait_for_index(
table: &dyn BaseTable,
index_names: &[&str],
timeout: Duration,
) -> Result<()> {
if timeout > MAX_WAIT {
return Err(Error::InvalidInput {
message: format!("timeout must be less than {:?}", MAX_WAIT),
});
}
let start = Instant::now();
let mut remaining = index_names.to_vec();
// poll via list_indices() and index_stats() until all indices are created and fully indexed
while start.elapsed() < timeout {
let mut completed = vec![];
let indices = table.list_indices().await?;
for &idx in &remaining {
if !indices.iter().any(|i| i.name == *idx) {
debug!("still waiting for new index '{}'", idx);
continue;
}
let stats = table.index_stats(idx.as_ref()).await?;
match stats {
None => {
debug!("still waiting for new index '{}'", idx);
continue;
}
Some(s) => {
if s.num_unindexed_rows == 0 {
// note: this may never stabilize under constant writes.
// we should later replace this with a status/job model
completed.push(idx);
debug!(
"fully indexed '{}'. indexed rows: {}",
idx, s.num_indexed_rows
);
} else {
debug!(
"still waiting for index '{}'. unindexed rows: {}",
idx, s.num_unindexed_rows
);
}
}
}
}
remaining.retain(|idx| !completed.contains(idx));
if remaining.is_empty() {
return Ok(());
}
sleep(Duration::from_millis(DEFAULT_SLEEP_MS)).await;
}
// debug log index diagnostics
for &r in &remaining {
let stats = table.index_stats(r.as_ref()).await?;
match stats {
Some(s) => debug!(
"index '{}' not fully indexed after {:?}. stats: {:?}",
r, timeout, s
),
None => debug!("index '{}' not found after {:?}", r, timeout),
}
}
Err(Error::Timeout {
message: format!(
"timed out waiting for indices: {:?} after {:?}",
remaining, timeout
),
})
}

View File

@@ -8,6 +8,7 @@
pub(crate) mod client;
pub(crate) mod db;
mod retry;
pub(crate) mod table;
pub(crate) mod util;

View File

@@ -1,17 +1,17 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use http::HeaderName;
use log::debug;
use reqwest::{
header::{HeaderMap, HeaderValue},
Request, RequestBuilder, Response,
Body, Request, RequestBuilder, Response,
};
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use crate::error::{Error, Result};
use crate::remote::db::RemoteOptions;
use crate::remote::retry::{ResolvedRetryConfig, RetryCounter};
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
@@ -118,41 +118,14 @@ pub struct RetryConfig {
/// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable
/// to set this value. Use a comma-separated list of integer values.
///
/// The default is 429, 500, 502, 503.
/// Note that write operations will never be retried on 5xx errors as this may
/// result in duplicated writes.
///
/// The default is 409, 429, 500, 502, 503, 504.
pub statuses: Option<Vec<u16>>,
// TODO: should we allow customizing methods?
}
#[derive(Debug, Clone)]
struct ResolvedRetryConfig {
retries: u8,
connect_retries: u8,
read_retries: u8,
backoff_factor: f32,
backoff_jitter: f32,
statuses: Vec<reqwest::StatusCode>,
}
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
type Error = Error;
fn try_from(retry_config: RetryConfig) -> Result<Self> {
Ok(Self {
retries: retry_config.retries.unwrap_or(3),
connect_retries: retry_config.connect_retries.unwrap_or(3),
read_retries: retry_config.read_retries.unwrap_or(3),
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
statuses: retry_config
.statuses
.unwrap_or_else(|| vec![429, 500, 502, 503])
.into_iter()
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
.collect(),
})
}
}
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
// we can mock responses in tests. Based on the patterns from this blog post:
// https://write.as/balrogboogie/testing-reqwest-based-clients
@@ -160,8 +133,8 @@ impl TryFrom<RetryConfig> for ResolvedRetryConfig {
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
client: reqwest::Client,
host: String,
retry_config: ResolvedRetryConfig,
sender: S,
pub(crate) retry_config: ResolvedRetryConfig,
pub(crate) sender: S,
}
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
@@ -375,74 +348,69 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
self.client.post(full_uri)
}
pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> {
let (client, request) = req.build_split();
let mut request = request.unwrap();
let request_id = self.extract_request_id(&mut request);
self.log_request(&request, &request_id);
// Set a request id.
// TODO: allow the user to supply this, through middleware?
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
let header = HeaderValue::from_str(&request_id).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, header);
request_id
};
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
if with_retry {
self.send_with_retry_impl(client, request, request_id).await
} else {
let response = self
.sender
.send(&client, request)
.await
.err_to_http(request_id.clone())?;
debug!(
"Received response for request_id={}: {:?}",
request_id, &response
);
Ok((request_id, response))
}
let response = self
.sender
.send(&client, request)
.await
.err_to_http(request_id.clone())?;
debug!(
"Received response for request_id={}: {:?}",
request_id, &response
);
Ok((request_id, response))
}
async fn send_with_retry_impl(
/// Send the request using retries configured in the RetryConfig.
/// If retry_5xx is false, 5xx requests will not be retried regardless of the statuses configured
/// in the RetryConfig.
/// Since this requires arrow serialization, this is implemented here instead of in RestfulLanceDbClient
pub async fn send_with_retry(
&self,
client: reqwest::Client,
req: Request,
request_id: String,
req_builder: RequestBuilder,
mut make_body: Option<Box<dyn FnMut() -> Result<Body> + Send + 'static>>,
retry_5xx: bool,
) -> Result<(String, Response)> {
let mut retry_counter = RetryCounter::new(&self.retry_config, request_id);
let retry_config = &self.retry_config;
let non_5xx_statuses = retry_config
.statuses
.iter()
.filter(|s| !s.is_server_error())
.cloned()
.collect::<Vec<_>>();
// clone and build the request to extract the request id
let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
let (_, r) = tmp_req.build_split();
let mut r = r.unwrap();
let request_id = self.extract_request_id(&mut r);
let mut retry_counter = RetryCounter::new(retry_config, request_id.clone());
loop {
// This only works if the request body is not a stream. If it is
// a stream, we can't use the retry path. We would need to implement
// an outer retry.
let request = req.try_clone().ok_or_else(|| Error::Runtime {
let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
let response = self
.sender
.send(&client, request)
.await
.map(|r| (r.status(), r));
// set the streaming body on the request builder after clone
if let Some(body_gen) = make_body.as_mut() {
let body = body_gen()?;
req_builder = req_builder.body(body);
}
let (c, request) = req_builder.build_split();
let mut request = request.unwrap();
self.set_request_id(&mut request, &request_id.clone());
self.log_request(&request, &request_id);
let response = self.sender.send(&c, request).await.map(|r| (r.status(), r));
match response {
Ok((status, response)) if status.is_success() => {
debug!(
@@ -451,7 +419,10 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
);
return Ok((retry_counter.request_id, response));
}
Ok((status, response)) if self.retry_config.statuses.contains(&status) => {
Ok((status, response))
if (retry_5xx && retry_config.statuses.contains(&status))
|| non_5xx_statuses.contains(&status) =>
{
let source = self
.check_response(&retry_counter.request_id, response)
.await
@@ -480,6 +451,47 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
}
fn log_request(&self, request: &Request, request_id: &String) {
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
}
/// Extract the request ID from the request headers.
/// If the request ID header is not set, this will generate a new one and set
/// it on the request headers
pub fn extract_request_id(&self, request: &mut Request) -> String {
// Set a request id.
// TODO: allow the user to supply this, through middleware?
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
self.set_request_id(request, &request_id);
request_id
};
request_id
}
/// Set the request ID header
pub fn set_request_id(&self, request: &mut Request, request_id: &str) {
let header = HeaderValue::from_str(request_id).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, header);
}
pub async fn check_response(&self, request_id: &str, response: Response) -> Result<Response> {
// Try to get the response text, but if that fails, just return the status code
let status = response.status();
@@ -501,91 +513,6 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
}
struct RetryCounter<'a> {
request_failures: u8,
connect_failures: u8,
read_failures: u8,
config: &'a ResolvedRetryConfig,
request_id: String,
}
impl<'a> RetryCounter<'a> {
fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
Self {
request_failures: 0,
connect_failures: 0,
read_failures: 0,
config,
request_id,
}
}
fn check_out_of_retries(
&self,
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
) -> Result<()> {
if self.request_failures >= self.config.retries
|| self.connect_failures >= self.config.connect_retries
|| self.read_failures >= self.config.read_retries
{
Err(Error::Retry {
request_id: self.request_id.clone(),
request_failures: self.request_failures,
max_request_failures: self.config.retries,
connect_failures: self.connect_failures,
max_connect_failures: self.config.connect_retries,
read_failures: self.read_failures,
max_read_failures: self.config.read_retries,
source,
status_code,
})
} else {
Ok(())
}
}
fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> {
self.request_failures += 1;
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.connect_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.read_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn next_sleep_time(&self) -> Duration {
let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
self.request_failures,
self.config.retries,
self.read_failures,
self.config.read_retries,
sleep_time
);
sleep_time
}
}
pub trait RequestResultExt {
type Output;
fn err_to_http(self, request_id: String) -> Result<Self::Output>;

View File

@@ -255,7 +255,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
if let Some(start_after) = request.start_after {
req = req.query(&[("page_token", start_after)]);
}
let (request_id, rsp) = self.client.send(req, true).await?;
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let version = parse_server_version(&request_id, &rsp)?;
let tables = rsp
@@ -302,7 +302,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
.body(data_buffer)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
let (request_id, rsp) = self.client.send(req, false).await?;
let (request_id, rsp) = self.client.send(req).await?;
if rsp.status() == StatusCode::BAD_REQUEST {
let body = rsp.text().await.err_to_http(request_id.clone())?;
@@ -362,7 +362,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
let req = self
.client
.post(&format!("/v1/table/{}/describe/", request.name));
let (request_id, rsp) = self.client.send(req, true).await?;
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
if rsp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: request.name });
}
@@ -383,7 +383,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
.client
.post(&format!("/v1/table/{}/rename/", current_name));
let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
let (request_id, resp) = self.client.send(req, false).await?;
let (request_id, resp) = self.client.send(req).await?;
self.client.check_response(&request_id, resp).await?;
let table = self.table_cache.remove(current_name).await;
if let Some(table) = table {
@@ -394,7 +394,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
async fn drop_table(&self, name: &str) -> Result<()> {
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
let (request_id, resp) = self.client.send(req, true).await?;
let (request_id, resp) = self.client.send(req).await?;
self.client.check_response(&request_id, resp).await?;
self.table_cache.remove(name).await;
Ok(())

View File

@@ -0,0 +1,122 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::remote::RetryConfig;
use crate::Error;
use log::debug;
use std::time::Duration;
pub struct RetryCounter<'a> {
pub request_failures: u8,
pub connect_failures: u8,
pub read_failures: u8,
pub config: &'a ResolvedRetryConfig,
pub request_id: String,
}
impl<'a> RetryCounter<'a> {
pub(crate) fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
Self {
request_failures: 0,
connect_failures: 0,
read_failures: 0,
config,
request_id,
}
}
fn check_out_of_retries(
&self,
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
) -> crate::Result<()> {
if self.request_failures >= self.config.retries
|| self.connect_failures >= self.config.connect_retries
|| self.read_failures >= self.config.read_retries
{
Err(Error::Retry {
request_id: self.request_id.clone(),
request_failures: self.request_failures,
max_request_failures: self.config.retries,
connect_failures: self.connect_failures,
max_connect_failures: self.config.connect_retries,
read_failures: self.read_failures,
max_read_failures: self.config.read_retries,
source,
status_code,
})
} else {
Ok(())
}
}
pub fn increment_request_failures(&mut self, source: crate::Error) -> crate::Result<()> {
self.request_failures += 1;
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn increment_connect_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
self.connect_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn increment_read_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
self.read_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn next_sleep_time(&self) -> Duration {
let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
self.request_failures,
self.config.retries,
self.read_failures,
self.config.read_retries,
sleep_time
);
sleep_time
}
}
#[derive(Debug, Clone)]
pub struct ResolvedRetryConfig {
pub retries: u8,
pub connect_retries: u8,
pub read_retries: u8,
pub backoff_factor: f32,
pub backoff_jitter: f32,
pub statuses: Vec<reqwest::StatusCode>,
}
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
type Error = Error;
fn try_from(retry_config: RetryConfig) -> crate::Result<Self> {
Ok(Self {
retries: retry_config.retries.unwrap_or(3),
connect_retries: retry_config.connect_retries.unwrap_or(3),
read_retries: retry_config.read_retries.unwrap_or(3),
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
statuses: retry_config
.statuses
.unwrap_or_else(|| vec![409, 429, 500, 502, 503, 504])
.into_iter()
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
.collect(),
})
}
}

View File

@@ -1,17 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error, Table};
use arrow_array::RecordBatchReader;
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
@@ -25,9 +21,19 @@ use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
use lance_datafusion::exec::{execute_plan, OneShotExec};
use reqwest::{RequestBuilder, Response};
use serde::{Deserialize, Serialize};
use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::RwLock;
use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
use crate::index::waiter::wait_for_index;
use crate::{
connection::NoData,
error::Result,
@@ -39,11 +45,6 @@ use crate::{
},
};
use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
#[derive(Debug)]
@@ -83,7 +84,7 @@ impl<S: HttpSend> RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
@@ -127,6 +128,61 @@ impl<S: HttpSend> RemoteTable<S> {
Ok(reqwest::Body::wrap_stream(body_stream))
}
/// Buffer the reader into memory
async fn buffer_reader<R: RecordBatchReader + ?Sized>(
reader: &mut R,
) -> Result<(SchemaRef, Vec<RecordBatch>)> {
let schema = reader.schema();
let mut batches = Vec::new();
for batch in reader {
batches.push(batch?);
}
Ok((schema, batches))
}
/// Create a new RecordBatchReader from buffered data
fn make_reader(schema: SchemaRef, batches: Vec<RecordBatch>) -> impl RecordBatchReader {
let iter = batches.into_iter().map(Ok);
RecordBatchIterator::new(iter, schema)
}
async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
let res = if with_retry {
self.client.send_with_retry(req, None, true).await?
} else {
self.client.send(req).await?
};
Ok(res)
}
/// Send the request with streaming body.
/// This will use retries if with_retry is set and the number of configured retries is > 0.
/// If retries are enabled, the stream will be buffered into memory.
async fn send_streaming(
&self,
req: RequestBuilder,
mut data: Box<dyn RecordBatchReader + Send>,
with_retry: bool,
) -> Result<(String, Response)> {
if !with_retry || self.client.retry_config.retries == 0 {
let body = Self::reader_as_body(data)?;
return self.client.send(req.body(body)).await;
}
// to support retries, buffer into memory and clone the batches on each retry
let (schema, batches) = Self::buffer_reader(&mut *data).await?;
let make_body = Box::new(move || {
let reader = Self::make_reader(schema.clone(), batches.clone());
Self::reader_as_body(Box::new(reader))
});
let res = self
.client
.send_with_retry(req, Some(make_body), false)
.await?;
Ok(res)
}
async fn check_table_response(
&self,
request_id: &str,
@@ -168,7 +224,8 @@ impl<S: HttpSend> RemoteTable<S> {
}
// Server requires k.
let limit = params.limit.unwrap_or(usize::MAX);
// use isize::MAX as usize to avoid overflow: https://github.com/lancedb/lancedb/issues/2211
let limit = params.limit.unwrap_or(isize::MAX as usize);
body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
if let Some(filter) = &params.filter {
@@ -339,8 +396,6 @@ impl<S: HttpSend> RemoteTable<S> {
let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name));
if let Some(timeout) = options.timeout {
// Client side timeout
request = request.timeout(timeout);
// Also send to server, so it can abort the query if it takes too long.
// (If it doesn't fit into u64, it's not worth sending anyways.)
if let Ok(timeout_ms) = u64::try_from(timeout.as_millis()) {
@@ -355,11 +410,29 @@ impl<S: HttpSend> RemoteTable<S> {
.collect();
let futures = requests.into_iter().map(|req| async move {
let (request_id, response) = self.client.send(req, true).await?;
let (request_id, response) = self.send(req, true).await?;
self.read_arrow_stream(&request_id, response).await
});
let streams = futures::future::try_join_all(futures).await?;
Ok(streams)
let streams = futures::future::try_join_all(futures);
if let Some(timeout) = options.timeout {
let timeout_future = tokio::time::sleep(timeout);
tokio::pin!(timeout_future);
tokio::pin!(streams);
tokio::select! {
_ = &mut timeout_future => {
Err(Error::Other {
message: format!("Query timeout after {} ms", timeout.as_millis()),
source: None,
})
}
result = &mut streams => {
Ok(result?)
}
}
} else {
Ok(streams.await?)
}
}
async fn prepare_query_bodies(&self, query: &AnyQuery) -> Result<Vec<serde_json::Value>> {
@@ -455,7 +528,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
self.checkout_latest().await?;
Ok(())
@@ -465,7 +538,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let request = self
.client
.post(&format!("/v1/table/{}/version/list/", self.name));
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
#[derive(Deserialize)]
@@ -511,7 +584,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
request = request.json(&body);
}
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
@@ -529,12 +602,10 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
self.check_mutable().await?;
let body = Self::reader_as_body(data)?;
let mut request = self
.client
.post(&format!("/v1/table/{}/insert/", self.name))
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
match add.mode {
AddDataMode::Append => {}
@@ -543,8 +614,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
}
}
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send_streaming(request, data, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
@@ -612,7 +682,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.collect::<Vec<_>>();
let futures = requests.into_iter().map(|req| async move {
let (request_id, response) = self.client.send(req, true).await?;
let (request_id, response) = self.send(req, true).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?;
@@ -654,7 +724,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.collect();
let futures = requests.into_iter().map(|req| async move {
let (request_id, response) = self.client.send(req, true).await?;
let (request_id, response) = self.send(req, true).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?;
@@ -696,7 +766,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
"predicate": update.filter,
}));
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
@@ -710,7 +780,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/delete/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -796,34 +866,45 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let request = request.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
if let Some(wait_timeout) = index.wait_timeout {
let name = format!("{}_idx", column);
self.wait_for_index(&[&name], wait_timeout).await?;
}
Ok(())
}
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(&self, index_names: &[&str], timeout: Duration) -> Result<()> {
wait_for_index(self, index_names, timeout).await
}
async fn merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
self.check_mutable().await?;
let query = MergeInsertRequest::try_from(params)?;
let body = Self::reader_as_body(new_data)?;
let request = self
.client
.post(&format!("/v1/table/{}/merge_insert/", self.name))
.query(&query)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send_streaming(request, new_data, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
self.check_mutable().await?;
Err(Error::NotSupported {
@@ -852,7 +933,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/add_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?; // todo:
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -891,7 +972,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/alter_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -903,7 +984,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/drop_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -917,7 +998,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
#[derive(Deserialize)]
@@ -974,7 +1055,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
@@ -998,7 +1079,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
"/v1/table/{}/index/{}/drop/",
self.name, index_name
));
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -1459,6 +1540,42 @@ mod tests {
assert_eq!(&body, &expected_body);
}
#[tokio::test]
async fn test_merge_insert_retries_on_409() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let data = Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
));
// Default parameters
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/");
let params = request.url().query_pairs().collect::<HashMap<_, _>>();
assert_eq!(params["on"], "some_col");
assert_eq!(params["when_matched_update_all"], "false");
assert_eq!(params["when_not_matched_insert_all"], "false");
assert_eq!(params["when_not_matched_by_source_delete"], "false");
assert!(!params.contains_key("when_matched_update_all_filt"));
assert!(!params.contains_key("when_not_matched_by_source_delete_filt"));
http::Response::builder().status(409).body("").unwrap()
});
let e = table
.merge_insert(&["some_col"])
.execute(data)
.await
.unwrap_err();
assert!(e.to_string().contains("Hit retry limit"));
}
#[tokio::test]
async fn test_delete() {
let table = Table::new_with_handler("my_table", |request| {
@@ -1500,7 +1617,7 @@ mod tests {
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let expected_body = serde_json::json!({
"k": usize::MAX,
"k": isize::MAX as usize,
"prefilter": true,
"vector": [], // Empty vector means no vector query.
"version": null,
@@ -2416,4 +2533,88 @@ mod tests {
});
table.drop_index("my_index").await.unwrap();
}
#[tokio::test]
async fn test_wait_for_index() {
let table = _make_table_with_indices(0);
table
.wait_for_index(&["vector_idx", "my_idx"], Duration::from_secs(1))
.await
.unwrap();
}
#[tokio::test]
async fn test_wait_for_index_timeout() {
let table = _make_table_with_indices(100);
let e = table
.wait_for_index(&["vector_idx", "my_idx"], Duration::from_secs(1))
.await
.unwrap_err();
assert_eq!(
e.to_string(),
"Timeout error: timed out waiting for indices: [\"vector_idx\", \"my_idx\"] after 1s"
);
}
#[tokio::test]
async fn test_wait_for_index_timeout_never_created() {
let table = _make_table_with_indices(0);
let e = table
.wait_for_index(&["doesnt_exist_idx"], Duration::from_secs(1))
.await
.unwrap_err();
assert_eq!(
e.to_string(),
"Timeout error: timed out waiting for indices: [\"doesnt_exist_idx\"] after 1s"
);
}
fn _make_table_with_indices(unindexed_rows: usize) -> Table {
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
let response_body = match request.url().path() {
"/v1/table/my_table/index/list/" => {
serde_json::json!({
"indexes": [
{
"index_name": "vector_idx",
"index_uuid": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
"columns": ["vector"],
"index_status": "done",
},
{
"index_name": "my_idx",
"index_uuid": "34255f64-5717-4562-b3fc-2c963f66afa6",
"columns": ["my_column"],
"index_status": "done",
},
]
})
}
"/v1/table/my_table/index/vector_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": unindexed_rows,
"index_type": "IVF_PQ",
"distance_type": "l2"
})
}
"/v1/table/my_table/index/my_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": unindexed_rows,
"index_type": "LABEL_LIST"
})
}
_path => {
serde_json::json!(None::<String>)
}
};
let body = serde_json::to_string(&response_body).unwrap();
let status = if body == "null" { 404 } else { 200 };
http::Response::builder().status(status).body(body).unwrap()
});
table
}
}

View File

@@ -3,10 +3,6 @@
//! LanceDB Table APIs
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
@@ -45,6 +41,10 @@ use lance_table::format::Manifest;
use lance_table::io::commit::ManifestNamingScheme;
use log::info;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::format;
use std::path::Path;
use std::sync::Arc;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
@@ -78,6 +78,7 @@ pub mod datafusion;
pub(crate) mod dataset;
pub mod merge;
use crate::index::waiter::wait_for_index;
pub use chrono::Duration;
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::scanner::DatasetRecordBatchStream;
@@ -491,6 +492,13 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn table_definition(&self) -> Result<TableDefinition>;
/// Get the table URI
fn dataset_uri(&self) -> &str;
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()>;
}
/// A Table is a collection of strong typed Rows.
@@ -769,6 +777,28 @@ impl Table {
)
}
/// See [Table::create_index]
/// For remote tables, this allows an optional wait_timeout to poll until asynchronous indexing is complete
pub fn create_index_with_timeout(
&self,
columns: &[impl AsRef<str>],
index: Index,
wait_timeout: Option<std::time::Duration>,
) -> IndexBuilder {
let mut builder = IndexBuilder::new(
self.inner.clone(),
columns
.iter()
.map(|val| val.as_ref().to_string())
.collect::<Vec<_>>(),
index,
);
if let Some(timeout) = wait_timeout {
builder = builder.wait_timeout(timeout);
}
builder
}
/// Create a builder for a merge insert operation
///
/// This operation can add rows, update rows, and remove rows all in a single
@@ -1104,6 +1134,16 @@ impl Table {
self.inner.prewarm_index(name).await
}
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
pub async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()> {
self.inner.wait_for_index(index_names, timeout).await
}
// Take many execution plans and map them into a single plan that adds
// a query_index column and unions them.
pub(crate) fn multi_vector_plan(
@@ -2430,6 +2470,16 @@ impl BaseTable for NativeTable {
loss,
}))
}
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()> {
wait_for_index(self, index_names, timeout).await
}
}
#[cfg(test)]
@@ -3213,7 +3263,10 @@ mod tests {
.execute()
.await
.unwrap();
table
.wait_for_index(&["embeddings_idx"], Duration::from_millis(10))
.await
.unwrap();
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();
@@ -3281,7 +3334,10 @@ mod tests {
.execute()
.await
.unwrap();
table
.wait_for_index(&["i_idx"], Duration::from_millis(10))
.await
.unwrap();
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();