Compare commits


33 Commits

Author SHA1 Message Date
Wyatt Alt
bc76671b62 add nodejs test 2026-02-28 22:18:37 -08:00
Wyatt Alt
bc814e1f66 lint js docs 2026-02-28 21:50:40 -08:00
Wyatt Alt
82a6e5a196 doctest update 2026-02-28 20:25:29 -08:00
Wyatt Alt
b3100f0e7c add test 2026-02-28 19:26:00 -08:00
Wyatt Alt
8ec60e3c9d feat: add num_deleted to delete result 2026-02-28 19:26:00 -08:00
Lance Release
0498ac1f2f Bump version: 0.27.0-beta.2 → 0.27.0-beta.3 2026-02-28 01:31:51 +00:00
Lance Release
aeb1c3ee6a Bump version: 0.30.0-beta.2 → 0.30.0-beta.3 2026-02-28 01:29:53 +00:00
Weston Pace
f9ae46c0e7 feat: upgrade lance to 3.0.0-rc.2 and add bindings for fast_search (#3083) 2026-02-27 17:27:01 -08:00
Will Jones
84bf022fb1 fix(python): pin pylance to make datafusion table provider match version (#3080) 2026-02-27 13:34:05 -08:00
Will Jones
310967eceb ci(rust): fix linux job (#3076) 2026-02-26 19:25:46 -08:00
Jack Ye
154dbeee2a chore: fix clippy for PreprocessingOutput without remote feature (#3070)
Fix clippy:

```
error: fields `overwrite` and `rescannable` are never read
Error:    --> /home/runner/work/xxxx/xxxx/src/lancedb/rust/lancedb/src/table/add_data.rs:158:9
    |
156 | pub struct PreprocessingOutput {
    |            ------------------- fields in this struct
157 |     pub plan: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
158 |     pub overwrite: bool,
    |         ^^^^^^^^^
159 |     pub rescannable: bool,
    |         ^^^^^^^^^^^
    |
    = note: `-D dead-code` implied by `-D warnings`
    = help: to override `-D warnings` add `#[allow(dead_code)]`
```
2026-02-25 14:59:32 -08:00
Lance Release
c9c08ac8b9 Bump version: 0.27.0-beta.1 → 0.27.0-beta.2 2026-02-25 07:47:54 +00:00
Lance Release
e253f5d9b6 Bump version: 0.30.0-beta.1 → 0.30.0-beta.2 2026-02-25 07:46:06 +00:00
LanceDB Robot
05b4fb0990 chore: update lance dependency to v3.1.0-beta.2 (#3068)
## Summary
- Bump Lance Rust workspace dependencies to `v3.1.0-beta.2` via
`ci/set_lance_version.py`.
- Update Java `lance-core.version` to `3.1.0-beta.2`.

## Verification
- `cargo clippy --workspace --tests --all-features -- -D warnings`
- `cargo fmt --all`

## Release Reference
- refs/tags/v3.1.0-beta.2
2026-02-24 23:02:22 -08:00
Mesut-Doner
613b9c1099 feat(rust): add expression builder API for type-safe query filters (#3032)
## Summary

Adds a Rust expression builder API as a type-safe alternative to SQL
strings for query filters.

## Motivation

Filtering with raw SQL strings can be awkward when using variables and
special types:


Closes  #3038

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2026-02-24 18:44:03 -08:00
Will Jones
d5948576b9 feat: parallel inserts for local tables (#3062)
When the input data is sufficiently large, we automatically split it into
parallel writes using a round-robin exchange operator. We sample the
first batch to estimate row width, and target a per-write size of 1 million
rows or 2 GB, whichever is smaller.
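The sizing heuristic above can be sketched roughly like this (a Python sketch for illustration only; the function and parameter names are hypothetical, not the actual Rust implementation):

```python
# Hypothetical sketch of the per-write sizing heuristic: sample one
# batch to estimate row width, then cap each parallel write at
# 1 million rows or 2 GB, whichever is reached first.
def target_rows_per_write(
    sampled_batch_bytes: int,
    sampled_batch_rows: int,
    max_rows: int = 1_000_000,
    max_bytes: int = 2 * 1024**3,
) -> int:
    # Estimate how wide a row is from the sampled batch.
    bytes_per_row = max(1, sampled_batch_bytes // max(1, sampled_batch_rows))
    # How many rows fit under the 2 GB byte cap.
    rows_for_byte_cap = max_bytes // bytes_per_row
    return max(1, min(max_rows, rows_for_byte_cap))
```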
2026-02-24 12:26:51 -08:00
Will Jones
0d3fc7860a ci: fix python DataFusion test (#3060) 2026-02-24 07:59:12 -08:00
Weston Pace
531cec075c fix: don't expect all offsets to fit in one batch in permutation reader (#3065)
This would cause takes against large permutations to fail.
2026-02-24 06:32:54 -08:00
Will Jones
0e486511fa feat: hook up new writer for insert (#3029)
This hooks up a new writer implementation for the `add()` method. The
main immediate benefit is that it allows streaming requests to remote
tables while still allowing retries for most inputs.

In NodeJS, we always convert the data to `Vec<RecordBatch>`, so it's
always retry-able.

For Python, all inputs are retry-able except `Iterator` and
`pa.RecordBatchReader`, which can only be consumed once. Some, like
`pa.datasets.Dataset`, are retry-able *and* streaming.

A lot of the changes here are to make the new DataFusion write pipeline
maintain the same behavior as the existing Python-based preprocessing,
such as:

* casting input data to target schema
* rejecting NaN values if `on_bad_vectors="error"`
* applying embedding functions.
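For example, the NaN rejection described above can be sketched as follows (illustrative Python; the real pipeline operates on Arrow batches inside DataFusion, and these names are hypothetical):

```python
import math

# Illustrative sketch of on_bad_vectors="error": reject any input
# vector containing NaN before it reaches the writer.
def check_vectors(vectors, on_bad_vectors="error"):
    for i, vec in enumerate(vectors):
        if any(math.isnan(x) for x in vec):
            if on_bad_vectors == "error":
                raise ValueError(f"vector at row {i} contains NaN")
    return vectors
```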

In future PRs, we'll enhance these by moving the embedding calls into
DataFusion and making sure we parallelize them. See:
https://github.com/lancedb/lancedb/issues/3048

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 14:43:31 -08:00
Will Jones
367262662d feat(nodejs): upgrade napi-rs from v2 to v3 (#3057)
## Summary

- Upgrades `@napi-rs/cli` from v2 to v3, `napi`/`napi-derive` Rust
crates to 3.x
- Fixes a bug
([napi-rs#1170](https://github.com/napi-rs/napi-rs/issues/1170)) where
the CLI failed to locate the built `.node` binary when a custom Cargo
target directory is set (via `config.toml`)

## Changes

**package.json / CLI**:
- `napi.name` → `napi.binaryName`, `napi.triples` → `napi.targets`
- Removed `--no-const-enum` flag and fixed output dir arg
- `napi universal` → `napi universalize`

**Rust API migration**:
- `#[napi::module_init]` → `#[napi_derive::module_init]`
- `napi::JsObject` → `Object`, `.get::<_, T>()` → `.get::<T>()`
- `ErrorStrategy` removed; `ThreadsafeFunction` now takes an explicit
`Return` type with `CalleeHandled = false` const generic
- `JsFunction` + `create_threadsafe_function` replaced by typed
`Function<Args, Return>` + `build_threadsafe_function().build()`
- `RerankerCallbacks` struct removed (`Function<'env,...>` can't be
stored in structs); `VectorQuery::rerank` now accepts the function
directly
- `ClassInstance::clone()` now returns `ClassInstance`, fixed with
explicit deref
- `Vec<u8>` in `#[napi(object)]` now maps to `Array<number>` in v3;
changed to `Buffer` to preserve the TypeScript `Buffer` type

**TypeScript**:
- `inner.rerank({ rerankHybrid: async (_, args) => ... })` →
`inner.rerank(async (args) => ...)`
- Header provider callback wrapped in `async` to match stricter typed
constructor signature

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-23 14:42:55 -08:00
Lance Release
11efaf46ae Bump version: 0.27.0-beta.0 → 0.27.0-beta.1 2026-02-23 18:34:48 +00:00
Lance Release
1ea22ee5ef Bump version: 0.30.0-beta.0 → 0.30.0-beta.1 2026-02-23 18:33:28 +00:00
LanceDB Robot
8cef8806e9 chore: update lance dependency to v3.0.0-beta.5 (#3058)
## Summary
- Bump Lance Rust dependencies and Java `lance-core` to v3.0.0-beta.5
(refs/tags/v3.0.0-beta.5).
- Update workspace toolchain and dependency defaults needed for the new
Lance release.
- Resolve new clippy lint defaults introduced by the toolchain update.

## Validation
- `cargo clippy --workspace --tests --all-features -- -D warnings`
- `cargo fmt --all`

---------

Co-authored-by: Jack Ye <yezhaoqin@gmail.com>
2026-02-23 00:39:30 -08:00
Will Jones
a3cd7fce69 fix: update DatasetConsistencyWrapper to accept same-version updates (#3055)
## Summary

`DatasetConsistencyWrapper::update()` only stored datasets with a
strictly newer version. This caused `migrate_manifest_paths_v2` to
silently drop its update, since the migration renames files without
bumping the dataset version. The subsequent `uses_v2_manifest_paths()`
call would then return the stale cached dataset.

Changed the version check from `>` to `>=` so same-version updates are
accepted.
## Test plan

- [x] Existing `test_create_table_v2_manifest_paths_async` Python test
should pass
- [x] Existing `should be able to migrate tables to the V2 manifest
paths` NodeJS test should pass
- [x] All dataset wrapper unit tests pass locally


🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 16:01:15 -08:00
Will Jones
48ddc833dd feat: check for dataset updates in the background (#3021)
This updates `DatasetConsistencyWrapper` to block less:

1. `DatasetConsistencyWrapper::get()` just returns `Arc<Dataset>` now,
instead of a guard that blocks writes.
`DatasetConsistencyWrapper::get_mut()` is gone; now write methods just
use `get()` and then later call `update()` with the new version. This
means a given table handle can do concurrent reads **and** writes.
2. In weak consistency mode, it will check for dataset updates in the
background instead of blocking calls to `get()`.
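The weak-consistency behavior in point 2 can be sketched as follows (Python threads standing in for the Rust background task; all names are illustrative):

```python
import threading

# Sketch: get() returns the cached snapshot immediately and triggers a
# background refresh, rather than blocking the caller on a version check.
class BackgroundRefresher:
    def __init__(self, snapshot, fetch_latest):
        self._snapshot = snapshot
        self._fetch_latest = fetch_latest
        self._lock = threading.Lock()

    def get(self):
        # Kick off the refresh without waiting for it to finish.
        threading.Thread(target=self._refresh, daemon=True).start()
        with self._lock:
            return self._snapshot

    def _refresh(self):
        latest = self._fetch_latest()
        with self._lock:
            self._snapshot = latest
```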

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-20 11:18:33 -08:00
Varun Chawla
2802764092 fix(embeddings): stop retrying OpenAI 401 authentication errors (#2995)
## Summary
Fixes #1679

This PR prevents the OpenAI embedding function from retrying when
receiving a 401 Unauthorized error. Authentication errors are permanent
failures that won't be fixed by retrying, yet the current implementation
retries all exceptions up to 7 times by default.

## Changes
- Modified `retry_with_exponential_backoff` in `utils.py` to check for
non-retryable errors before retrying
- Added `_is_non_retryable_error` helper function that detects:
  - Exceptions with name `AuthenticationError` (OpenAI's 401 error)
  - Exceptions with `status_code` attribute of 401 or 403
- Enhanced OpenAI embeddings to explicitly catch and re-raise
`AuthenticationError` with better logging
- Added unit test `test_openai_no_retry_on_401` to verify authentication
errors don't trigger retries

## Test Plan
- Added test that verifies:
  1. A function raising `AuthenticationError` is only called once
  2. No retry delays occur (sleep is never called)
- Existing tests continue to pass
- Formatting applied via `make format`

## Example Behavior

**Before**: With an invalid API key, users would see 7 retry attempts
over ~2 minutes:
```
WARNING:root:Error occurred: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
 Retrying in 3.97 seconds (retry 1 of 7)
WARNING:root:Error occurred: Error code: 401...
 Retrying in 7.94 seconds (retry 2 of 7)
...
```

**After**: With an invalid API key, the error is raised immediately:
```
ERROR:root:Authentication failed: Invalid API key provided
AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
```

This provides better UX and prevents unnecessary API calls that would
fail anyway.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2026-02-19 09:20:54 -08:00
Weston Pace
37bbb0dba1 fix: allow permutation reader to work with remote tables as well (#3047)
Fixed one more spot that was relying on `_inner`.
2026-02-19 00:41:41 +05:30
Prashanth Rao
155ec16161 fix: deprecate outdated files for embedding registry (#3037)
There are old and outdated files in our embedding registry that can
confuse coding agents. This PR deprecates the following files, which
have been superseded by newer, more modern ways to generate such
embeddings.

- Deprecate `embeddings/siglip.py` 
- Deprecate `embeddings/gte.py` 

## Why this change?

Per a discussion with @AyushExel, the [embedding registry directory
](1840aa7edc/python/python/lancedb/embeddings)
in the LanceDB repo has a number of outdated files that need to be
deprecated.

See https://github.com/lancedb/docs/issues/85 for the docs gaps that
identified this.
- Add note in `openclip` docs that it can be used for SigLip embeddings,
which it now supports
- Add note in the `sentence-transformers` page that ALL text embedding
models on Hugging Face can be used
2026-02-18 12:04:39 -05:00
Weston Pace
636b8b5bbd fix: allow permutation reader to be used with remote tables (#3019)
There were two issues:

1. The python code needs to get access to the underlying rust table to
set up the permutation reader, and the attributes involved differ
between the python local table and remote table objects.
~~2. The remote table was sending projection dictionaries as arrays of
tuples and (on LanceDB cloud at least) it does not appear this is how
rest servers are set up to receive them.~~ (this is now fixed in #3023)

~~Leaving as draft as this is built on
https://github.com/lancedb/lancedb/pull/3016~~
2026-02-18 05:44:08 -08:00
Omair Afzal
715b81c86b fix(python): graceful handling of empty result sets in hybrid search (#3030)
## Problem

When applying hard filters that result in zero matches, hybrid search
crashes with `IndexError: list index out of range` during reranking.
This happens because empty result tables are passed through the full
reranker pipeline, which expects at least one result.

Traceback from the issue:
```
lancedb/query.py: in _combine_hybrid_results
    results = reranker.rerank_hybrid(fts_query, vector_results, fts_results)
lancedb/rerankers/answerdotai.py: in rerank_hybrid
    combined_results = self._rerank(combined_results, query)
...
IndexError: list index out of range
```

## Fix

Added an early return in `_combine_hybrid_results` when both vector and
FTS results are empty. Instead of passing empty tables through
normalization, reranking, and score restoration (which can fail in
various ways), we now build a properly-typed empty result table with the
`_relevance_score` column and return it directly.
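The shape of the fix, with plain dicts standing in for Arrow tables (illustrative only; the real `_combine_hybrid_results` works on `pa.Table` objects):

```python
# Sketch of the early return: if both result sets are empty, build a
# typed empty result instead of pushing empty tables through the
# reranker, which expects at least one row.
def combine_hybrid_results(vector_results, fts_results, rerank):
    if not vector_results and not fts_results:
        return {"rows": [], "columns": ["_relevance_score"]}
    return rerank(vector_results, fts_results)
```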

## Test

Added `test_empty_hybrid_result_reranker` that exercises
`_combine_hybrid_results` directly with empty vector and FTS tables,
verifying:
- Returns empty table with correct schema  
- Includes `_relevance_score` column
- Respects `with_row_ids` flag

Closes #2425
2026-02-17 11:37:10 -08:00
Omair Afzal
7e1616376e refactor: extract merge_insert into table/merge.rs submodule (#3031)
Completes the **merge_insert.rs** checklist item from #2949.

## Changes

- Moved `MergeResult` struct from `table.rs` to `table/merge.rs`
- Moved the `NativeTable::merge_insert` implementation into
`merge::execute_merge_insert()`, with the trait impl now delegating to
it (same pattern as `delete.rs`)
- Moved `test_merge_insert` and `test_merge_insert_use_index` tests into
`table/merge.rs`
- Improved moved tests to use `memory://` URIs instead of temporary
directories
- Cleaned up unused imports from `table.rs` (`FutureExt`,
`TryFutureExt`, `Either`, `WhenMatched`, `WhenNotMatchedBySource`,
`LanceMergeInsertBuilder`)
- `MergeResult` is re-exported from `table.rs` so the public API is
unchanged

## Testing

`cargo build -p lancedb` compiles cleanly with no warnings.
2026-02-17 11:36:53 -08:00
ChinmayGowda71
d5ac5b949a refactor(rust): extract query logic to src/table/query.rs (#3035)
References #2949. Moved query logic and helpers from `table.rs` to
`query.rs`. Refactored tests per the guidelines and added coverage for
the multi-vector plan structure.
2026-02-17 09:04:21 -08:00
Lance Release
7be6f45e0b Bump version: 0.26.2 → 0.27.0-beta.0 2026-02-17 00:28:24 +00:00
86 changed files with 8210 additions and 2599 deletions


@@ -1,5 +1,5 @@
[tool.bumpversion]
-current_version = "0.26.2"
+current_version = "0.27.0-beta.3"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.


@@ -8,6 +8,7 @@ on:
paths:
- Cargo.toml
- nodejs/**
- rust/**
- docs/src/js/**
- .github/workflows/nodejs.yml
- docker-compose.yml
@@ -77,8 +78,11 @@ jobs:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v3
name: Setup Node.js 20 for build
with:
node-version: ${{ matrix.node-version }}
# @napi-rs/cli v3 requires Node >= 20.12 (via @inquirer/prompts@8).
# Build always on Node 20; tests run on the matrix version below.
node-version: 20
cache: 'npm'
cache-dependency-path: nodejs/package-lock.json
- uses: Swatinem/rust-cache@v2
@@ -86,12 +90,16 @@ jobs:
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
npm install -g @napi-rs/cli
- name: Build
run: |
npm ci --include=optional
npm run build:debug -- --profile ci
npm run tsc
- uses: actions/setup-node@v3
name: Setup Node.js ${{ matrix.node-version }} for test
with:
node-version: ${{ matrix.node-version }}
- name: Compile TypeScript
run: npm run tsc
- name: Setup localstack
working-directory: .
run: docker compose up --detach --wait
@@ -144,7 +152,6 @@ jobs:
- name: Install dependencies
run: |
brew install protobuf
npm install -g @napi-rs/cli
- name: Build
run: |
npm ci --include=optional


@@ -128,16 +128,13 @@ jobs:
- target: x86_64-unknown-linux-musl
# This one seems to need some extra memory
host: ubuntu-2404-8x-x64
# https://github.com/napi-rs/napi-rs/blob/main/alpine.Dockerfile
docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-alpine
features: fp16kernels
pre_build: |-
set -e &&
apk add protobuf-dev curl &&
ln -s /usr/lib/gcc/x86_64-alpine-linux-musl/14.2.0/crtbeginS.o /usr/lib/crtbeginS.o &&
ln -s /usr/lib/libgcc_s.so /usr/lib/libgcc.so &&
CC=gcc &&
CXX=g++
sudo apt-get update &&
sudo apt-get install -y protobuf-compiler pkg-config &&
rustup target add x86_64-unknown-linux-musl &&
export EXTRA_ARGS="-x"
- target: aarch64-unknown-linux-gnu
host: ubuntu-2404-8x-x64
# https://github.com/napi-rs/napi-rs/blob/main/debian-aarch64.Dockerfile
@@ -153,15 +150,13 @@ jobs:
rustup target add aarch64-unknown-linux-gnu
- target: aarch64-unknown-linux-musl
host: ubuntu-2404-8x-x64
# https://github.com/napi-rs/napi-rs/blob/main/alpine.Dockerfile
docker: ghcr.io/napi-rs/napi-rs/nodejs-rust:lts-alpine
features: ","
pre_build: |-
set -e &&
apk add protobuf-dev &&
sudo apt-get update &&
sudo apt-get install -y protobuf-compiler &&
rustup target add aarch64-unknown-linux-musl &&
export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc &&
export CXX_aarch64_unknown_linux_musl=aarch64-linux-musl-g++
export EXTRA_ARGS="-x"
name: build - ${{ matrix.settings.target }}
runs-on: ${{ matrix.settings.host }}
defaults:
@@ -192,12 +187,18 @@ jobs:
.cargo-cache
target/
key: nodejs-${{ matrix.settings.target }}-cargo-${{ matrix.settings.host }}
- name: Setup toolchain
run: ${{ matrix.settings.setup }}
if: ${{ matrix.settings.setup }}
shell: bash
- name: Install dependencies
run: npm ci
- name: Install Zig
uses: mlugg/setup-zig@v2
if: ${{ contains(matrix.settings.target, 'musl') }}
with:
version: 0.14.1
- name: Install cargo-zigbuild
uses: taiki-e/install-action@v2
if: ${{ contains(matrix.settings.target, 'musl') }}
with:
tool: cargo-zigbuild
- name: Build in docker
uses: addnab/docker-run-action@v3
if: ${{ matrix.settings.docker }}
@@ -210,24 +211,24 @@ jobs:
run: |
set -e
${{ matrix.settings.pre_build }}
-npx napi build --platform --release --no-const-enum \
+npx napi build --platform --release \
--features ${{ matrix.settings.features }} \
--target ${{ matrix.settings.target }} \
--dts ../lancedb/native.d.ts \
--js ../lancedb/native.js \
--strip \
-dist/
+--output-dir dist/
- name: Build
run: |
${{ matrix.settings.pre_build }}
-npx napi build --platform --release --no-const-enum \
+npx napi build --platform --release \
--features ${{ matrix.settings.features }} \
--target ${{ matrix.settings.target }} \
--dts ../lancedb/native.d.ts \
--js ../lancedb/native.js \
--strip \
$EXTRA_ARGS \
-dist/
+--output-dir dist/
if: ${{ !matrix.settings.docker }}
shell: bash
- name: Upload artifact


@@ -8,6 +8,7 @@ on:
paths:
- Cargo.toml
- python/**
- rust/**
- .github/workflows/python.yml
concurrency:


@@ -100,7 +100,9 @@ jobs:
lfs: true
- uses: Swatinem/rust-cache@v2
- name: Install dependencies
-run: sudo apt install -y protobuf-compiler libssl-dev
+run: |
+sudo apt update
+sudo apt install -y protobuf-compiler libssl-dev
- uses: rui314/setup-mold@v1
- name: Make Swap
run: |
@@ -183,7 +185,7 @@ jobs:
runs-on: ubuntu-24.04
strategy:
matrix:
-msrv: ["1.88.0"] # This should match up with rust-version in Cargo.toml
+msrv: ["1.91.0"] # This should match up with rust-version in Cargo.toml
env:
# Need up-to-date compilers for kernels
CC: clang-18

Cargo.lock (generated, 877 lines): file diff suppressed because it is too large


@@ -12,23 +12,23 @@ repository = "https://github.com/lancedb/lancedb"
description = "Serverless, low-latency vector database for AI applications"
keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
-rust-version = "1.88.0"
+rust-version = "1.91.0"
[workspace.dependencies]
-lance = { "version" = "=2.0.1", default-features = false }
-lance-core = "=2.0.1"
-lance-datagen = "=2.0.1"
-lance-file = "=2.0.1"
-lance-io = { "version" = "=2.0.1", default-features = false }
-lance-index = "=2.0.1"
-lance-linalg = "=2.0.1"
-lance-namespace = "=2.0.1"
-lance-namespace-impls = { "version" = "=2.0.1", default-features = false }
-lance-table = "=2.0.1"
-lance-testing = "=2.0.1"
-lance-datafusion = "=2.0.1"
-lance-encoding = "=2.0.1"
-lance-arrow = "=2.0.1"
+lance = { "version" = "=3.0.0-rc.2", default-features = false, "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-core = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-datagen = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-file = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-io = { "version" = "=3.0.0-rc.2", default-features = false, "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-index = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-linalg = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace-impls = { "version" = "=3.0.0-rc.2", default-features = false, "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-table = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-testing = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-datafusion = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-encoding = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
+lance-arrow = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false }
@@ -40,13 +40,15 @@ arrow-schema = "57.2"
arrow-select = "57.2"
arrow-cast = "57.2"
async-trait = "0"
-datafusion = { version = "51.0", default-features = false }
-datafusion-catalog = "51.0"
-datafusion-common = { version = "51.0", default-features = false }
-datafusion-execution = "51.0"
-datafusion-expr = "51.0"
-datafusion-physical-plan = "51.0"
-datafusion-physical-expr = "51.0"
+datafusion = { version = "52.1", default-features = false }
+datafusion-catalog = "52.1"
+datafusion-common = { version = "52.1", default-features = false }
+datafusion-execution = "52.1"
+datafusion-expr = "52.1"
+datafusion-functions = "52.1"
+datafusion-physical-plan = "52.1"
+datafusion-physical-expr = "52.1"
+datafusion-sql = "52.1"
env_logger = "0.11"
half = { "version" = "2.7.1", default-features = false, features = [
"num-traits",


@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
-<version>0.26.2</version>
+<version>0.27.0-beta.3</version>
</dependency>
```


@@ -8,6 +8,14 @@
## Properties
+### numDeletedRows
+```ts
+numDeletedRows: number;
+```
+***
### version
```ts


@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
-<version>0.26.2-final.0</version>
+<version>0.27.0-beta.3</version>
<relativePath>../pom.xml</relativePath>
</parent>


@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
-<version>0.26.2-final.0</version>
+<version>0.27.0-beta.3</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
-<lance-core.version>2.0.1</lance-core.version>
+<lance-core.version>3.1.0-beta.2</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>


@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
-version = "0.26.2"
+version = "0.27.0-beta.3"
license.workspace = true
description.workspace = true
repository.workspace = true
@@ -19,11 +19,11 @@ arrow-schema.workspace = true
env_logger.workspace = true
futures.workspace = true
lancedb = { path = "../rust/lancedb", default-features = false }
-napi = { version = "2.16.8", default-features = false, features = [
+napi = { version = "3.8.3", default-features = false, features = [
"napi9",
"async"
] }
-napi-derive = "2.16.4"
+napi-derive = "3.5.2"
# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }
log.workspace = true
@@ -33,7 +33,7 @@ aws-lc-sys = "=0.28.0"
aws-lc-rs = "=1.13.0"
[build-dependencies]
-napi-build = "2.1"
+napi-build = "2.3.1"
[features]
default = ["remote", "lancedb/aws", "lancedb/gcs", "lancedb/azure", "lancedb/dynamodb", "lancedb/oss", "lancedb/huggingface"]


@@ -450,6 +450,31 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
},
);
+describe("delete", () => {
+let tmpDir: tmp.DirResult;
+let table: Table;
+beforeEach(async () => {
+tmpDir = tmp.dirSync({ unsafeCleanup: true });
+const conn = await connect(tmpDir.name);
+table = await conn.createTable("delete_test", [
+{ id: 1, value: "a" },
+{ id: 2, value: "b" },
+{ id: 3, value: "c" },
+{ id: 4, value: "d" },
+{ id: 5, value: "e" },
+]);
+});
+afterEach(() => tmpDir.removeCallback());
+test("returns num_deleted_rows", async () => {
+const result = await table.delete("id > 3");
+expect(result.numDeletedRows).toBe(2);
+expect(result.version).toBe(2);
+expect(await table.countRows()).toBe(3);
+});
+});
describe("merge insert", () => {
let tmpDir: tmp.DirResult;
let table: Table;


@@ -273,7 +273,9 @@ export async function connect(
let nativeProvider: NativeJsHeaderProvider | undefined;
if (finalHeaderProvider) {
if (typeof finalHeaderProvider === "function") {
-nativeProvider = new NativeJsHeaderProvider(finalHeaderProvider);
+nativeProvider = new NativeJsHeaderProvider(async () =>
+finalHeaderProvider(),
+);
} else if (
finalHeaderProvider &&
typeof finalHeaderProvider.getHeaders === "function"


@@ -684,19 +684,17 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
rerank(reranker: Reranker): VectorQuery {
super.doCall((inner) =>
-inner.rerank({
-rerankHybrid: async (_, args) => {
-const vecResults = await fromBufferToRecordBatch(args.vecResults);
-const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
-const result = await reranker.rerankHybrid(
-args.query,
-vecResults as RecordBatch,
-ftsResults as RecordBatch,
-);
+inner.rerank(async (args) => {
+const vecResults = await fromBufferToRecordBatch(args.vecResults);
+const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
+const result = await reranker.rerankHybrid(
+args.query,
+vecResults as RecordBatch,
+ftsResults as RecordBatch,
+);
-const buffer = fromRecordBatchToBuffer(result);
-return buffer;
-},
+const buffer = fromRecordBatchToBuffer(result);
+return buffer;
+}),
);


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"os": [
"win32"
],


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

nodejs/package-lock.json (generated, 1781 lines): file diff suppressed because it is too large


@@ -11,7 +11,7 @@
"ann"
],
"private": false,
-"version": "0.26.2",
+"version": "0.27.0-beta.3",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",
@@ -21,19 +21,16 @@
},
"types": "dist/index.d.ts",
"napi": {
-"name": "lancedb",
-"triples": {
-"defaults": false,
-"additional": [
-"aarch64-apple-darwin",
-"x86_64-unknown-linux-gnu",
-"aarch64-unknown-linux-gnu",
-"x86_64-unknown-linux-musl",
-"aarch64-unknown-linux-musl",
-"x86_64-pc-windows-msvc",
-"aarch64-pc-windows-msvc"
-]
-}
+"binaryName": "lancedb",
+"targets": [
+"aarch64-apple-darwin",
+"x86_64-unknown-linux-gnu",
+"aarch64-unknown-linux-gnu",
+"x86_64-unknown-linux-musl",
+"aarch64-unknown-linux-musl",
+"x86_64-pc-windows-msvc",
+"aarch64-pc-windows-msvc"
+]
},
"license": "Apache-2.0",
"repository": {
@@ -46,7 +43,7 @@
"@aws-sdk/client-s3": "^3.33.0",
"@biomejs/biome": "^1.7.3",
"@jest/globals": "^29.7.0",
-"@napi-rs/cli": "^2.18.3",
+"@napi-rs/cli": "^3.5.1",
"@types/axios": "^0.14.0",
"@types/jest": "^29.1.2",
"@types/node": "^22.7.4",
@@ -75,9 +72,9 @@
"os": ["darwin", "linux", "win32"],
"scripts": {
"artifacts": "napi artifacts",
-"build:debug": "napi build --platform --no-const-enum --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
+"build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js --output-dir lancedb",
"postbuild:debug": "shx mkdir -p dist && shx cp lancedb/*.node dist/",
-"build:release": "napi build --platform --no-const-enum --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
+"build:release": "napi build --platform --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js --output-dir dist",
"postbuild:release": "shx mkdir -p dist && shx cp lancedb/*.node dist/",
"build": "npm run build:debug && npm run tsc",
"build-release": "npm run build:release && npm run tsc",
@@ -91,7 +88,7 @@
"prepublishOnly": "napi prepublish -t npm",
"test": "jest --verbose",
"integration": "S3_TEST=1 npm run test",
-"universal": "napi universal",
+"universal": "napi universalize",
"version": "napi version"
},
"dependencies": {


@@ -1,20 +1,19 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
-use napi::{
-bindgen_prelude::*,
-threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
-};
+use napi::{bindgen_prelude::*, threadsafe_function::ThreadsafeFunction};
use napi_derive::napi;
use std::collections::HashMap;
use std::sync::Arc;
type GetHeadersFn = ThreadsafeFunction<(), Promise<HashMap<String, String>>, (), Status, false>;
/// JavaScript HeaderProvider implementation that wraps a JavaScript callback.
/// This is the only native header provider - all header provider implementations
/// should provide a JavaScript function that returns headers.
#[napi]
pub struct JsHeaderProvider {
-get_headers_fn: Arc<ThreadsafeFunction<(), ErrorStrategy::CalleeHandled>>,
+get_headers_fn: Arc<GetHeadersFn>,
}
impl Clone for JsHeaderProvider {
@@ -29,9 +28,12 @@ impl Clone for JsHeaderProvider {
impl JsHeaderProvider {
/// Create a new JsHeaderProvider from a JavaScript callback
#[napi(constructor)]
-pub fn new(get_headers_callback: JsFunction) -> Result<Self> {
+pub fn new(
+get_headers_callback: Function<(), Promise<HashMap<String, String>>>,
+) -> Result<Self> {
let get_headers_fn = get_headers_callback
.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))
.build_threadsafe_function()
.build()
.map_err(|e| {
Error::new(
Status::GenericFailure,
@@ -51,7 +53,7 @@ impl lancedb::remote::HeaderProvider for JsHeaderProvider {
async fn get_headers(&self) -> lancedb::error::Result<HashMap<String, String>> {
// Call the JavaScript function asynchronously
let promise: Promise<HashMap<String, String>> =
self.get_headers_fn.call_async(Ok(())).await.map_err(|e| {
self.get_headers_fn.call_async(()).await.map_err(|e| {
lancedb::error::Error::Runtime {
message: format!("Failed to call JavaScript get_headers: {}", e),
}


@@ -60,7 +60,7 @@ pub struct OpenTableOptions {
pub storage_options: Option<HashMap<String, String>>,
}
#[napi::module_init]
#[napi_derive::module_init]
fn init() {
let env = Env::new()
.filter_or("LANCEDB_LOG", "warn")


@@ -20,8 +20,8 @@ use napi_derive::napi;
use crate::error::convert_error;
use crate::error::NapiErrorExt;
use crate::iterator::RecordBatchIterator;
use crate::rerankers::RerankHybridCallbackArgs;
use crate::rerankers::Reranker;
use crate::rerankers::RerankerCallbacks;
use crate::util::{parse_distance_type, schema_to_buffer};
#[napi]
@@ -42,7 +42,7 @@ impl Query {
}
#[napi]
pub fn full_text_search(&mut self, query: napi::JsObject) -> napi::Result<()> {
pub fn full_text_search(&mut self, query: Object) -> napi::Result<()> {
let query = parse_fts_query(query)?;
self.inner = self.inner.clone().full_text_search(query);
Ok(())
@@ -235,7 +235,7 @@ impl VectorQuery {
}
#[napi]
pub fn full_text_search(&mut self, query: napi::JsObject) -> napi::Result<()> {
pub fn full_text_search(&mut self, query: Object) -> napi::Result<()> {
let query = parse_fts_query(query)?;
self.inner = self.inner.clone().full_text_search(query);
Ok(())
@@ -272,11 +272,13 @@ impl VectorQuery {
}
#[napi]
pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
self.inner = self
.inner
.clone()
.rerank(Arc::new(Reranker::new(callbacks)));
pub fn rerank(
&mut self,
rerank_hybrid: Function<RerankHybridCallbackArgs, Promise<Buffer>>,
) -> napi::Result<()> {
let reranker = Reranker::new(rerank_hybrid)?;
self.inner = self.inner.clone().rerank(Arc::new(reranker));
Ok(())
}
#[napi(catch_unwind)]
@@ -523,12 +525,12 @@ impl JsFullTextQuery {
}
}
fn parse_fts_query(query: napi::JsObject) -> napi::Result<FullTextSearchQuery> {
if let Ok(Some(query)) = query.get::<_, &JsFullTextQuery>("query") {
fn parse_fts_query(query: Object) -> napi::Result<FullTextSearchQuery> {
if let Ok(Some(query)) = query.get::<&JsFullTextQuery>("query") {
Ok(FullTextSearchQuery::new_query(query.inner.clone()))
} else if let Ok(Some(query_text)) = query.get::<_, String>("query") {
} else if let Ok(Some(query_text)) = query.get::<String>("query") {
let mut query_text = query_text;
let columns = query.get::<_, Option<Vec<String>>>("columns")?.flatten();
let columns = query.get::<Option<Vec<String>>>("columns")?.flatten();
let is_phrase =
query_text.len() >= 2 && query_text.starts_with('"') && query_text.ends_with('"');


@@ -3,10 +3,7 @@
use arrow_array::RecordBatch;
use async_trait::async_trait;
use napi::{
bindgen_prelude::*,
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
};
use napi::{bindgen_prelude::*, threadsafe_function::ThreadsafeFunction};
use napi_derive::napi;
use lancedb::ipc::batches_to_ipc_file;
@@ -15,27 +12,28 @@ use lancedb::{error::Error, ipc::ipc_file_to_batches};
use crate::error::NapiErrorExt;
type RerankHybridFn = ThreadsafeFunction<
RerankHybridCallbackArgs,
Promise<Buffer>,
RerankHybridCallbackArgs,
Status,
false,
>;
/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
/// This contains references to the callbacks that can be used to invoke the
/// reranking methods on the NodeJS implementation and handles serializing the
/// record batches to Arrow IPC buffers.
#[napi]
pub struct Reranker {
/// callback to the Javascript which will call the rerankHybrid method of
/// some Reranker implementation
rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
rerank_hybrid: RerankHybridFn,
}
#[napi]
impl Reranker {
#[napi]
pub fn new(callbacks: RerankerCallbacks) -> Self {
let rerank_hybrid = callbacks
.rerank_hybrid
.create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
.unwrap();
Self { rerank_hybrid }
pub fn new(
rerank_hybrid: Function<RerankHybridCallbackArgs, Promise<Buffer>>,
) -> napi::Result<Self> {
let rerank_hybrid = rerank_hybrid.build_threadsafe_function().build()?;
Ok(Self { rerank_hybrid })
}
}
@@ -49,16 +47,16 @@ impl lancedb::rerankers::Reranker for Reranker {
) -> lancedb::error::Result<RecordBatch> {
let callback_args = RerankHybridCallbackArgs {
query: query.to_string(),
vec_results: batches_to_ipc_file(&[vector_results])?,
fts_results: batches_to_ipc_file(&[fts_results])?,
vec_results: Buffer::from(batches_to_ipc_file(&[vector_results])?.as_ref()),
fts_results: Buffer::from(batches_to_ipc_file(&[fts_results])?.as_ref()),
};
let promised_buffer: Promise<Buffer> = self
.rerank_hybrid
.call_async(Ok(callback_args))
.call_async(callback_args)
.await
.map_err(|e| Error::Runtime {
message: format!("napi error status={}, reason={}", e.status, e.reason),
})?;
let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
message: format!("napi error status={}, reason={}", e.status, e.reason),
})?;
@@ -77,16 +75,11 @@ impl std::fmt::Debug for Reranker {
}
}
#[napi(object)]
pub struct RerankerCallbacks {
pub rerank_hybrid: JsFunction,
}
#[napi(object)]
pub struct RerankHybridCallbackArgs {
pub query: String,
pub vec_results: Vec<u8>,
pub fts_results: Vec<u8>,
pub vec_results: Buffer,
pub fts_results: Buffer,
}
fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {


@@ -96,7 +96,6 @@ impl napi::bindgen_prelude::FromNapiValue for Session {
) -> napi::Result<Self> {
let object: napi::bindgen_prelude::ClassInstance<Self> =
napi::bindgen_prelude::ClassInstance::from_napi_value(env, napi_val)?;
let copy = object.clone();
Ok(copy)
Ok((*object).clone())
}
}


@@ -71,6 +71,17 @@ impl Table {
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<AddResult> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let batches = batches
.into_iter()
.map(|batch| {
batch.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to read record batch from IPC file: {}",
e
))
})
})
.collect::<Result<Vec<_>>>()?;
let mut op = self.inner_ref()?.add(batches);
op = if mode == "append" {
@@ -742,12 +753,14 @@ impl From<lancedb::table::AddResult> for AddResult {
#[napi(object)]
pub struct DeleteResult {
pub num_deleted_rows: i64,
pub version: i64,
}
impl From<lancedb::table::DeleteResult> for DeleteResult {
fn from(value: lancedb::table::DeleteResult) -> Self {
Self {
num_deleted_rows: value.num_deleted_rows as i64,
version: value.version as i64,
}
}


@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.30.0-beta.0"
current_version = "0.30.0-beta.3"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.


@@ -1,13 +1,13 @@
[package]
name = "lancedb-python"
version = "0.30.0-beta.0"
version = "0.30.0-beta.3"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
rust-version = "1.88.0"
rust-version = "1.91.0"
[lib]
name = "_lancedb"


@@ -59,9 +59,9 @@ tests = [
"polars>=0.19, <=1.3.0",
"tantivy",
"pyarrow-stubs",
"pylance>=1.0.0b14",
"pylance>=1.0.0b14,<3.0.0",
"requests",
"datafusion",
"datafusion<52",
]
dev = [
"ruff",


@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from functools import singledispatch
from typing import List, Optional, Tuple, Union
from lancedb.pydantic import LanceModel, model_to_dict
import pyarrow as pa
from ._lancedb import RecordBatchStream
@@ -80,3 +82,32 @@ def peek_reader(
yield from reader
return batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches())
@singledispatch
def to_arrow(data) -> pa.Table:
"""Convert a single data object to a pa.Table."""
raise NotImplementedError(f"to_arrow not implemented for type {type(data)}")
@to_arrow.register(pa.RecordBatch)
def _arrow_from_batch(data: pa.RecordBatch) -> pa.Table:
return pa.Table.from_batches([data])
@to_arrow.register(pa.Table)
def _arrow_from_table(data: pa.Table) -> pa.Table:
return data
@to_arrow.register(list)
def _arrow_from_list(data: list) -> pa.Table:
if not data:
raise ValueError("Cannot create table from empty list without a schema")
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
dicts = [model_to_dict(d) for d in data]
return pa.Table.from_pylist(dicts, schema=schema)
return pa.Table.from_pylist(data)
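The new `to_arrow` helper above relies on `functools.singledispatch`, which picks an implementation from the runtime type of the first argument and falls back to the generic function for unregistered types. A stdlib-only sketch of the same pattern (the `to_rows` name is a hypothetical stand-in so it runs without pyarrow):

```python
from functools import singledispatch

@singledispatch
def to_rows(data):
    # Generic fallback, mirroring the diff: unknown types raise.
    raise NotImplementedError(f"to_rows not implemented for type {type(data)}")

@to_rows.register(dict)
def _rows_from_dict(data: dict):
    # A single mapping becomes a one-row list.
    return [data]

@to_rows.register(list)
def _rows_from_list(data: list):
    if not data:
        raise ValueError("Cannot create rows from an empty list")
    return data

rows = to_rows([{"a": 1}, {"a": 2}])
```

Registration order does not matter; dispatch is by the most specific registered type.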


@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import warnings
from typing import List, Union
import numpy as np
@@ -15,6 +16,8 @@ from .utils import weak_lru
@register("gte-text")
class GteEmbeddings(TextEmbeddingFunction):
"""
Deprecated: GTE embeddings should be used through sentence-transformers.
An embedding function that uses GTE-LARGE in MLX format (for Apple Silicon devices only)
as well as the standard cpu/gpu version from: https://huggingface.co/thenlper/gte-large.
@@ -61,6 +64,13 @@ class GteEmbeddings(TextEmbeddingFunction):
def __init__(self, **kwargs):
super().__init__(**kwargs)
warnings.warn(
"GTE embeddings as a standalone embedding function are deprecated. "
"Use the 'sentence-transformers' embedding function with a GTE model "
"instead.",
DeprecationWarning,
stacklevel=3,
)
self._ndims = None
if kwargs:
self.mlx = kwargs.get("mlx", False)


@@ -110,6 +110,9 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
valid_embeddings = {
idx: v.embedding for v, idx in zip(rs.data, valid_indices)
}
except openai.AuthenticationError:
logging.error("Authentication failed: Invalid API key provided")
raise
except openai.BadRequestError:
logging.exception("Bad request: %s", texts)
return [None] * len(texts)


@@ -6,6 +6,7 @@ import io
import os
from typing import TYPE_CHECKING, List, Union
import urllib.parse as urlparse
import warnings
import numpy as np
import pyarrow as pa
@@ -24,6 +25,7 @@ if TYPE_CHECKING:
@register("siglip")
class SigLipEmbeddings(EmbeddingFunction):
# Deprecated: prefer CLIP embeddings via `open-clip`.
model_name: str = "google/siglip-base-patch16-224"
device: str = "cpu"
batch_size: int = 64
@@ -36,6 +38,12 @@ class SigLipEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
"SigLip embeddings are deprecated. Use CLIP embeddings via the "
"'open-clip' embedding function instead.",
DeprecationWarning,
stacklevel=3,
)
transformers = attempt_import_or_raise("transformers")
self._torch = attempt_import_or_raise("torch")


@@ -269,6 +269,11 @@ def retry_with_exponential_backoff(
# and say that it is assumed that if this portion errors out, it's due
# to rate limit but the user should check the error message to be sure.
except Exception as e: # noqa: PERF203
# Don't retry on authentication errors (e.g., OpenAI 401)
# These are permanent failures that won't be fixed by retrying
if _is_non_retryable_error(e):
raise
num_retries += 1
if num_retries > max_retries:
@@ -289,6 +294,29 @@ def retry_with_exponential_backoff(
return wrapper
def _is_non_retryable_error(error: Exception) -> bool:
"""Check if an error should not be retried.
Args:
error: The exception to check
Returns:
True if the error should not be retried, False otherwise
"""
# Check for OpenAI authentication errors
error_type = type(error).__name__
if error_type == "AuthenticationError":
return True
# Check for other common non-retryable HTTP status codes
# 401 Unauthorized, 403 Forbidden
if hasattr(error, "status_code"):
if error.status_code in (401, 403):
return True
return False
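The change above makes the backoff loop short-circuit permanent failures instead of retrying them. A minimal stdlib sketch of that control flow, with local stand-in names (`is_non_retryable`, `retry_with_backoff`) rather than the library's own helpers:

```python
import time

def is_non_retryable(error: Exception) -> bool:
    # Match by class name (e.g. OpenAI's AuthenticationError) or HTTP status.
    if type(error).__name__ == "AuthenticationError":
        return True
    return getattr(error, "status_code", None) in (401, 403)

def retry_with_backoff(func, max_retries=3, base_delay=0.01):
    def wrapper(*args, **kwargs):
        delay = base_delay
        for attempt in range(max_retries + 1):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # Permanent failures (and exhausted budgets) propagate at once.
                if is_non_retryable(e) or attempt == max_retries:
                    raise
                time.sleep(delay)
                delay *= 2
    return wrapper

class AuthenticationError(Exception):
    pass

calls = {"n": 0}

@retry_with_backoff
def always_401():
    calls["n"] += 1
    raise AuthenticationError("Invalid API key")
```

With this wiring, `always_401()` raises on the first call and is never retried.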
def url_retrieve(url: str):
"""
Parameters


@@ -44,7 +44,7 @@ from lance_namespace import (
ListNamespacesRequest,
CreateNamespaceRequest,
DropNamespaceRequest,
CreateEmptyTableRequest,
DeclareTableRequest,
)
from lancedb.table import AsyncTable, LanceTable, Table
from lancedb.util import validate_table_name
@@ -318,20 +318,20 @@ class LanceNamespaceDBConnection(DBConnection):
if location is None:
# Table doesn't exist or mode is "create", reserve a new location
create_empty_request = CreateEmptyTableRequest(
declare_request = DeclareTableRequest(
id=table_id,
location=None,
properties=self.storage_options if self.storage_options else None,
)
create_empty_response = self._ns.create_empty_table(create_empty_request)
declare_response = self._ns.declare_table(declare_request)
if not create_empty_response.location:
if not declare_response.location:
raise ValueError(
"Table location is missing from create_empty_table response"
"Table location is missing from declare_table response"
)
location = create_empty_response.location
namespace_storage_options = create_empty_response.storage_options
location = declare_response.location
namespace_storage_options = declare_response.storage_options
# Merge storage options: self.storage_options < user options < namespace options
merged_storage_options = dict(self.storage_options)
@@ -759,20 +759,20 @@ class AsyncLanceNamespaceDBConnection:
if location is None:
# Table doesn't exist or mode is "create", reserve a new location
create_empty_request = CreateEmptyTableRequest(
declare_request = DeclareTableRequest(
id=table_id,
location=None,
properties=self.storage_options if self.storage_options else None,
)
create_empty_response = self._ns.create_empty_table(create_empty_request)
declare_response = self._ns.declare_table(declare_request)
if not create_empty_response.location:
if not declare_response.location:
raise ValueError(
"Table location is missing from create_empty_table response"
"Table location is missing from declare_table response"
)
location = create_empty_response.location
namespace_storage_options = create_empty_response.storage_options
location = declare_response.location
namespace_storage_options = declare_response.storage_options
# Merge storage options: self.storage_options < user options < namespace options
merged_storage_options = dict(self.storage_options)


@@ -1462,6 +1462,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
self._phrase_query = False
self.ordering_field_name = ordering_field_name
self._reranker = None
self._fast_search = None
if isinstance(fts_columns, str):
fts_columns = [fts_columns]
self._fts_columns = fts_columns
@@ -1483,6 +1484,19 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
self._phrase_query = phrase_query
return self
def fast_search(self) -> LanceFtsQueryBuilder:
"""
Skip a flat search of unindexed data. This will improve
search performance but search results will not include unindexed data.
Returns
-------
LanceFtsQueryBuilder
The LanceFtsQueryBuilder object.
"""
self._fast_search = True
return self
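Like the other builder methods, `fast_search()` sets a flag and returns `self`, so it chains freely with `select()` and `limit()`. A toy stand-in class (not the real `LanceFtsQueryBuilder`) showing only that chaining contract:

```python
class FtsQueryBuilder:
    """Hypothetical miniature of the chainable builder pattern above."""

    def __init__(self, query: str):
        self._query = query
        self._fast_search = None  # None = unset, mirroring the diff's default
        self._limit = None
        self._columns = None

    def fast_search(self):
        self._fast_search = True
        return self

    def limit(self, n: int):
        self._limit = n
        return self

    def select(self, cols):
        self._columns = list(cols)
        return self

q = FtsQueryBuilder("xyz").fast_search().select(["text"]).limit(5)
```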
def to_query_object(self) -> Query:
return Query(
columns=self._columns,
@@ -1494,6 +1508,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
query=self._query, columns=self._fts_columns
),
offset=self._offset,
fast_search=self._fast_search,
)
def output_schema(self) -> pa.Schema:
@@ -1782,6 +1797,26 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
vector_results = LanceHybridQueryBuilder._rank(vector_results, "_distance")
fts_results = LanceHybridQueryBuilder._rank(fts_results, "_score")
# If both result sets are empty (e.g. after hard filtering),
# return early to avoid errors in reranking or score restoration.
if vector_results.num_rows == 0 and fts_results.num_rows == 0:
# Build a minimal empty table with the _relevance_score column
combined_schema = pa.unify_schemas(
[vector_results.schema, fts_results.schema],
)
empty = pa.table(
{
col: pa.array([], type=combined_schema.field(col).type)
for col in combined_schema.names
}
)
empty = empty.append_column(
"_relevance_score", pa.array([], type=pa.float32())
)
if not with_row_ids and "_rowid" in empty.column_names:
empty = empty.drop(["_rowid"])
return empty
original_distances = None
original_scores = None
original_distance_row_ids = None


@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from dataclasses import dataclass
from functools import singledispatch
import sys
from typing import Callable, Iterator, Optional
from lancedb.arrow import to_arrow
import pyarrow as pa
import pyarrow.dataset as ds
from .pydantic import LanceModel
@dataclass
class Scannable:
schema: pa.Schema
num_rows: Optional[int]
# Factory function to create a new reader each time (supports re-scanning)
reader: Callable[[], pa.RecordBatchReader]
# Whether reader can be called more than once. For example, an iterator can
# only be consumed once, while a DataFrame can be converted to a new reader
# each time.
rescannable: bool = True
@singledispatch
def to_scannable(data) -> Scannable:
# Fallback: try iterable protocol
if hasattr(data, "__iter__"):
return _from_iterable(iter(data))
raise NotImplementedError(f"to_scannable not implemented for type {type(data)}")
@to_scannable.register(pa.RecordBatchReader)
def _from_reader(data: pa.RecordBatchReader) -> Scannable:
# RecordBatchReader can only be consumed once - not rescannable
return Scannable(
schema=data.schema, num_rows=None, reader=lambda: data, rescannable=False
)
@to_scannable.register(pa.RecordBatch)
def _from_batch(data: pa.RecordBatch) -> Scannable:
return Scannable(
schema=data.schema,
num_rows=data.num_rows,
reader=lambda: pa.RecordBatchReader.from_batches(data.schema, [data]),
)
@to_scannable.register(pa.Table)
def _from_table(data: pa.Table) -> Scannable:
return Scannable(schema=data.schema, num_rows=data.num_rows, reader=data.to_reader)
@to_scannable.register(ds.Dataset)
def _from_dataset(data: ds.Dataset) -> Scannable:
return Scannable(
schema=data.schema,
num_rows=data.count_rows(),
reader=lambda: data.scanner().to_reader(),
)
@to_scannable.register(ds.Scanner)
def _from_scanner(data: ds.Scanner) -> Scannable:
# Scanner can only be consumed once - not rescannable
return Scannable(
schema=data.projected_schema,
num_rows=None,
reader=data.to_reader,
rescannable=False,
)
@to_scannable.register(list)
def _from_list(data: list) -> Scannable:
if not data:
raise ValueError("Cannot create table from empty list without a schema")
table = to_arrow(data)
return Scannable(
schema=table.schema, num_rows=table.num_rows, reader=table.to_reader
)
@to_scannable.register(dict)
def _from_dict(data: dict) -> Scannable:
raise ValueError("Cannot add a single dictionary to a table. Use a list.")
@to_scannable.register(LanceModel)
def _from_lance_model(data: LanceModel) -> Scannable:
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
def _from_iterable(data: Iterator) -> Scannable:
first_item = next(data, None)
if first_item is None:
raise ValueError("Cannot create table from empty iterator")
first = to_arrow(first_item)
schema = first.schema
def iter():
yield from first.to_batches()
for item in data:
batch = to_arrow(item)
if batch.schema != schema:
try:
batch = batch.cast(schema)
except pa.lib.ArrowInvalid:
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the schema of other batches.\n"
f"Expected:\n{schema}\nGot:\n{batch.schema}"
)
yield from batch.to_batches()
reader = pa.RecordBatchReader.from_batches(schema, iter())
return to_scannable(reader)
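The `rescannable` flag above encodes whether `reader` can be called more than once: a materialized table can mint a fresh reader per call, while a one-shot iterator is consumed by its first scan. A simplified, locally re-declared stand-in (plain dicts instead of record batches) illustrating that contract:

```python
from dataclasses import dataclass
from typing import Callable, Iterator, Optional

@dataclass
class Scannable:
    """Simplified stand-in for the dataclass above; no Arrow types involved."""
    num_rows: Optional[int]
    reader: Callable[[], Iterator[dict]]
    rescannable: bool = True

rows = [{"a": 1}, {"a": 2}]
# Factory builds a fresh iterator each call, so re-scanning works.
rescan = Scannable(num_rows=len(rows), reader=lambda: iter(rows))

# A captured iterator is shared across calls, so it scans only once.
one_shot_iter = iter([{"a": 3}])
one_shot = Scannable(num_rows=None, reader=lambda: one_shot_iter, rescannable=False)
```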
_registered_modules: set[str] = set()
def _register_optional_converters():
"""Register converters for optional dependencies that are already imported."""
if "pandas" in sys.modules and "pandas" not in _registered_modules:
_registered_modules.add("pandas")
import pandas as pd
@to_arrow.register(pd.DataFrame)
def _arrow_from_pandas(data: pd.DataFrame) -> pa.Table:
table = pa.Table.from_pandas(data, preserve_index=False)
return table.replace_schema_metadata(None)
@to_scannable.register(pd.DataFrame)
def _from_pandas(data: pd.DataFrame) -> Scannable:
return to_scannable(_arrow_from_pandas(data))
if "polars" in sys.modules and "polars" not in _registered_modules:
_registered_modules.add("polars")
import polars as pl
@to_arrow.register(pl.DataFrame)
def _arrow_from_polars(data: pl.DataFrame) -> pa.Table:
return data.to_arrow()
@to_scannable.register(pl.DataFrame)
def _from_polars(data: pl.DataFrame) -> Scannable:
arrow = data.to_arrow()
return Scannable(
schema=arrow.schema, num_rows=len(data), reader=arrow.to_reader
)
@to_scannable.register(pl.LazyFrame)
def _from_polars_lazy(data: pl.LazyFrame) -> Scannable:
arrow = data.collect().to_arrow()
return Scannable(
schema=arrow.schema, num_rows=arrow.num_rows, reader=arrow.to_reader
)
if "datasets" in sys.modules and "datasets" not in _registered_modules:
_registered_modules.add("datasets")
from datasets import Dataset as HFDataset
from datasets import DatasetDict as HFDatasetDict
@to_scannable.register(HFDataset)
def _from_hf_dataset(data: HFDataset) -> Scannable:
table = data.data.table # Access underlying Arrow table
return Scannable(
schema=table.schema, num_rows=len(data), reader=table.to_reader
)
@to_scannable.register(HFDatasetDict)
def _from_hf_dataset_dict(data: HFDatasetDict) -> Scannable:
# HuggingFace DatasetDict: combine all splits with a 'split' column
schema = data[list(data.keys())[0]].features.arrow_schema
if "split" not in schema.names:
schema = schema.append(pa.field("split", pa.string()))
def gen():
for split_name, dataset in data.items():
for batch in dataset.data.to_batches():
split_arr = pa.array(
[split_name] * len(batch), type=pa.string()
)
yield pa.RecordBatch.from_arrays(
list(batch.columns) + [split_arr], schema=schema
)
total_rows = sum(len(dataset) for dataset in data.values())
return Scannable(
schema=schema,
num_rows=total_rows,
reader=lambda: pa.RecordBatchReader.from_batches(schema, gen()),
)
if "lance" in sys.modules and "lance" not in _registered_modules:
_registered_modules.add("lance")
import lance
@to_scannable.register(lance.LanceDataset)
def _from_lance(data: lance.LanceDataset) -> Scannable:
return Scannable(
schema=data.schema,
num_rows=data.count_rows(),
reader=lambda: data.scanner().to_reader(),
)
# Register on module load
_register_optional_converters()


@@ -25,6 +25,8 @@ from typing import (
)
from urllib.parse import urlparse
from lancedb.scannable import _register_optional_converters, to_scannable
from . import __version__
from lancedb.arrow import peek_reader
from lancedb.background_loop import LOOP
@@ -1329,7 +1331,7 @@ class Table(ABC):
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.delete("x = 2")
DeleteResult(version=2)
DeleteResult(num_deleted_rows=1, version=2)
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
@@ -1343,7 +1345,7 @@ class Table(ABC):
>>> to_remove
'1, 5'
>>> table.delete(f"x IN ({to_remove})")
DeleteResult(version=3)
DeleteResult(num_deleted_rows=1, version=3)
>>> table.to_pandas()
x vector
0 3 [5.0, 6.0]
@@ -3727,18 +3729,31 @@ class AsyncTable:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
data = _sanitize_data(
data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
allow_subschema=True,
)
if isinstance(data, pa.Table):
data = data.to_reader()
return await self._inner.add(data, mode or "append")
# _sanitize_data is an old code path, but we will use it until the
# new code path is ready.
if on_bad_vectors != "error" or (
schema.metadata is not None and b"embedding_functions" in schema.metadata
):
data = _sanitize_data(
data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
allow_subschema=True,
)
_register_optional_converters()
data = to_scannable(data)
try:
return await self._inner.add(data, mode or "append")
except RuntimeError as e:
if "Cast error" in str(e):
raise ValueError(e)
elif "Vector column contains NaN" in str(e):
raise ValueError(e)
else:
raise
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""
@@ -4200,7 +4215,7 @@ class AsyncTable:
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.delete("x = 2")
DeleteResult(version=2)
DeleteResult(num_deleted_rows=1, version=2)
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
@@ -4214,7 +4229,7 @@ class AsyncTable:
>>> to_remove
'1, 5'
>>> table.delete(f"x IN ({to_remove})")
DeleteResult(version=3)
DeleteResult(num_deleted_rows=1, version=3)
>>> table.to_pandas()
x vector
0 3 [5.0, 6.0]


@@ -515,3 +515,34 @@ def test_openai_propagates_api_key(monkeypatch):
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
assert len(actual.text) > 0
@patch("time.sleep")
def test_openai_no_retry_on_401(mock_sleep):
"""
Test that OpenAI embedding function does not retry on 401 authentication
errors.
"""
from lancedb.embeddings.utils import retry_with_exponential_backoff
# Create a mock that raises an AuthenticationError
class MockAuthenticationError(Exception):
"""Mock OpenAI AuthenticationError"""
pass
MockAuthenticationError.__name__ = "AuthenticationError"
mock_func = MagicMock(side_effect=MockAuthenticationError("Invalid API key"))
# Wrap the function with retry logic
wrapped_func = retry_with_exponential_backoff(mock_func, max_retries=3)
# Should raise without retrying
with pytest.raises(MockAuthenticationError):
wrapped_func()
# Verify that the function was only called once (no retries)
assert mock_func.call_count == 1
# Verify that sleep was never called (no retries)
assert mock_sleep.call_count == 0


@@ -882,3 +882,105 @@ def test_fts_query_to_json():
'"must_not":[]}}'
)
assert json_str == expected
def test_fts_fast_search(table):
table.create_fts_index("text", use_tantivy=False)
# Insert some unindexed data
table.add(
[
{
"text": "xyz",
"vector": [0 for _ in range(128)],
"id": 101,
"text2": "xyz",
"nested": {"text": "xyz"},
"count": 10,
}
]
)
# Without fast_search, the query object should not have fast_search set
builder = table.search("xyz", query_type="fts").limit(10)
query = builder.to_query_object()
assert query.fast_search is None
# With fast_search, the query object should have fast_search=True
builder = table.search("xyz", query_type="fts").fast_search().limit(10)
query = builder.to_query_object()
assert query.fast_search is True
# fast_search should be chainable with other methods
builder = (
table.search("xyz", query_type="fts").fast_search().select(["text"]).limit(5)
)
query = builder.to_query_object()
assert query.fast_search is True
assert query.limit == 5
assert query.columns == ["text"]
# Verify it executes without error and skips unindexed data
results = table.search("xyz", query_type="fts").fast_search().limit(5).to_list()
assert len(results) == 0
# Update index and verify it returns results
table.optimize()
results = table.search("xyz", query_type="fts").fast_search().limit(5).to_list()
assert len(results) > 0
@pytest.mark.asyncio
async def test_fts_fast_search_async(async_table):
await async_table.create_index("text", config=FTS())
# Insert some unindexed data
await async_table.add(
[
{
"text": "xyz",
"vector": [0 for _ in range(128)],
"id": 101,
"text2": "xyz",
"nested": {"text": "xyz"},
"count": 10,
}
]
)
# Without fast_search, should return results
results = await async_table.query().nearest_to_text("xyz").limit(5).to_list()
assert len(results) > 0
# With fast_search, should return no results since the added data is unindexed
fast_results = (
await async_table.query()
.nearest_to_text("xyz")
.fast_search()
.limit(5)
.to_list()
)
assert len(fast_results) == 0
# Update index and verify it returns results
await async_table.optimize()
fast_results = (
await async_table.query()
.nearest_to_text("xyz")
.fast_search()
.limit(5)
.to_list()
)
assert len(fast_results) > 0
# fast_search should be chainable with other methods
results = (
await async_table.query()
.nearest_to_text("xyz")
.fast_search()
.select(["text"])
.limit(5)
.to_list()
)
assert len(results) > 0


@@ -531,6 +531,78 @@ def test_empty_result_reranker():
)
def test_empty_hybrid_result_reranker():
"""Test that hybrid search with empty results after filtering doesn't crash.
Regression test for https://github.com/lancedb/lancedb/issues/2425
"""
from lancedb.query import LanceHybridQueryBuilder
# Simulate empty vector and FTS results with the expected schema
vector_schema = pa.schema(
[
("text", pa.string()),
("vector", pa.list_(pa.float32(), 4)),
("_rowid", pa.uint64()),
("_distance", pa.float32()),
]
)
fts_schema = pa.schema(
[
("text", pa.string()),
("vector", pa.list_(pa.float32(), 4)),
("_rowid", pa.uint64()),
("_score", pa.float32()),
]
)
empty_vector = pa.table(
{
"text": pa.array([], type=pa.string()),
"vector": pa.array([], type=pa.list_(pa.float32(), 4)),
"_rowid": pa.array([], type=pa.uint64()),
"_distance": pa.array([], type=pa.float32()),
},
schema=vector_schema,
)
empty_fts = pa.table(
{
"text": pa.array([], type=pa.string()),
"vector": pa.array([], type=pa.list_(pa.float32(), 4)),
"_rowid": pa.array([], type=pa.uint64()),
"_score": pa.array([], type=pa.float32()),
},
schema=fts_schema,
)
for reranker in [LinearCombinationReranker(), RRFReranker()]:
result = LanceHybridQueryBuilder._combine_hybrid_results(
fts_results=empty_fts,
vector_results=empty_vector,
norm="score",
fts_query="nonexistent query",
reranker=reranker,
limit=10,
with_row_ids=False,
)
assert len(result) == 0
assert "_relevance_score" in result.column_names
assert "_rowid" not in result.column_names
# Also test with with_row_ids=True
result = LanceHybridQueryBuilder._combine_hybrid_results(
fts_results=empty_fts,
vector_results=empty_vector,
norm="score",
fts_query="nonexistent query",
reranker=LinearCombinationReranker(),
limit=10,
with_row_ids=True,
)
assert len(result) == 0
assert "_relevance_score" in result.column_names
assert "_rowid" in result.column_names
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_cross_encoder_reranker_return_all(tmp_path, use_tantivy):
pytest.importorskip("sentence_transformers")


@@ -810,7 +810,7 @@ def test_create_index_name_and_train_parameters(
)
def test_add_with_nans(mem_db: DBConnection):
def test_create_with_nans(mem_db: DBConnection):
# by default we raise an error on bad input vectors
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
@@ -854,6 +854,57 @@ def test_add_with_nans(mem_db: DBConnection):
assert np.allclose(v, np.array([0.0, 0.0]))
def test_add_with_nans(mem_db: DBConnection):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
pa.field("item", pa.string(), nullable=True),
pa.field("price", pa.float64(), nullable=False),
],
)
table = mem_db.create_table("test", schema=schema)
# by default we raise an error on bad input vectors
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
]
for row in bad_data:
with pytest.raises(ValueError):
table.add(
data=[row],
)
table.add(
[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [2.1, 4.1], "item": "foo", "price": 9.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
on_bad_vectors="drop",
)
assert len(table) == 2
table.delete("true")
# We can fill bad input with some value
table.add(
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
on_bad_vectors="fill",
fill_value=0.0,
)
assert len(table) == 3
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
def test_restore(mem_db: DBConnection):
table = mem_db.create_table(
"my_table",

View File

@@ -292,18 +292,14 @@ class TestModel(lancedb.pydantic.LanceModel):
lambda: pa.table({"a": [1], "b": [2]}),
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
)
),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
).scanner()
lambda: lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
),
lambda: lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
).scanner(),
lambda: pd.DataFrame({"a": [1], "b": [2]}),
lambda: pl.DataFrame({"a": [1], "b": [2]}),
lambda: pl.LazyFrame({"a": [1], "b": [2]}),

View File

@@ -23,10 +23,25 @@ use pyo3::{
};
use pyo3_async_runtimes::tokio::future_into_py;
fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
if table.hasattr("_inner")? {
Ok(table.getattr("_inner")?.downcast_into::<Table>()?)
} else if table.hasattr("_table")? {
Ok(table
.getattr("_table")?
.getattr("_inner")?
.downcast_into::<Table>()?)
} else {
Err(PyRuntimeError::new_err(
"Provided table does not appear to be a Table or RemoteTable instance",
))
}
}
/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(table: Bound<'_, PyAny>) -> PyResult<PyAsyncPermutationBuilder> {
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
let table = table_from_py(table)?;
let inner_table = table.borrow().inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
@@ -250,10 +265,8 @@ impl PyPermutationReader {
permutation_table: Option<Bound<'py, PyAny>>,
split: u64,
) -> PyResult<Bound<'py, PyAny>> {
let base_table = base_table.getattr("_inner")?.downcast_into::<Table>()?;
let permutation_table = permutation_table
.map(|p| PyResult::Ok(p.getattr("_inner")?.downcast_into::<Table>()?))
.transpose()?;
let base_table = table_from_py(base_table)?;
let permutation_table = permutation_table.map(table_from_py).transpose()?;
let base_table = base_table.borrow().inner_ref()?.base_table().clone();
let permutation_table = permutation_table

View File
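The `table_from_py` helper accepts either a `Table` (exposing `_inner`) or a `RemoteTable` wrapper (exposing `_table._inner`). A minimal Python sketch of the same duck-typed fallback chain, using stand-in classes rather than the real lancedb types:

```python
class Table:
    # Stand-in for the Python-side Table wrapper holding a native handle.
    def __init__(self, inner):
        self._inner = inner

class RemoteTable:
    # Stand-in for a wrapper that holds a Table under `_table`.
    def __init__(self, table):
        self._table = table

def table_from_obj(obj):
    # Mirrors the attribute-probing order in table_from_py above:
    # try `_inner` first, then `_table._inner`, else reject.
    if hasattr(obj, "_inner"):
        return obj._inner
    if hasattr(obj, "_table"):
        return obj._table._inner
    raise TypeError(
        "Provided table does not appear to be a Table or RemoteTable instance"
    )

t = Table("inner-handle")
r = RemoteTable(t)
```

Both wrapper shapes resolve to the same inner handle, which is what lets `async_permutation_builder` and `PyPermutationReader` accept either kind of table.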

@@ -7,6 +7,7 @@ use crate::{
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
table::scannable::PyScannable,
};
use arrow::{
datatypes::{DataType, Schema},
@@ -25,6 +26,8 @@ use pyo3::{
};
use pyo3_async_runtimes::tokio::future_into_py;
mod scannable;
/// Statistics about a compaction operation.
#[pyclass(get_all)]
#[derive(Clone, Debug)]
@@ -109,19 +112,24 @@ impl From<lancedb::table::AddResult> for AddResult {
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct DeleteResult {
pub num_deleted_rows: u64,
pub version: u64,
}
#[pymethods]
impl DeleteResult {
pub fn __repr__(&self) -> String {
format!("DeleteResult(version={})", self.version)
format!(
"DeleteResult(num_deleted_rows={}, version={})",
self.num_deleted_rows, self.version
)
}
}
impl From<lancedb::table::DeleteResult> for DeleteResult {
fn from(result: lancedb::table::DeleteResult) -> Self {
Self {
num_deleted_rows: result.num_deleted_rows,
version: result.version,
}
}
@@ -293,12 +301,10 @@ impl Table {
pub fn add<'a>(
self_: PyRef<'a, Self>,
data: Bound<'_, PyAny>,
data: PyScannable,
mode: String,
) -> PyResult<Bound<'a, PyAny>> {
let batches: Box<dyn arrow::array::RecordBatchReader + Send> =
Box::new(ArrowArrayStreamReader::from_pyarrow_bound(&data)?);
let mut op = self_.inner_ref()?.add(batches);
let mut op = self_.inner_ref()?.add(data);
if mode == "append" {
op = op.mode(AddDataMode::Append);
} else if mode == "overwrite" {

View File
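With `num_deleted_rows` added to `DeleteResult`, the repr now reports both fields. A small Python sketch of the equivalent object shape (an illustrative mirror of the pyclass, not the actual binding):

```python
from dataclasses import dataclass

@dataclass
class DeleteResult:
    # Mirrors the #[pyclass(get_all)] struct: both fields are readable.
    num_deleted_rows: int
    version: int

    def __repr__(self) -> str:
        # Matches the new format string in the Rust __repr__.
        return (
            f"DeleteResult(num_deleted_rows={self.num_deleted_rows}, "
            f"version={self.version})"
        )

result = DeleteResult(num_deleted_rows=3, version=7)
```

Callers that previously only saw the post-delete `version` can now also tell how many rows the delete actually removed.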

@@ -0,0 +1,145 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use arrow::{
datatypes::{Schema, SchemaRef},
ffi_stream::ArrowArrayStreamReader,
pyarrow::{FromPyArrow, PyArrowType},
};
use futures::StreamExt;
use lancedb::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
data::scannable::Scannable,
Error,
};
use pyo3::{types::PyAnyMethods, FromPyObject, Py, PyAny, Python};
/// Adapter that implements Scannable for a Python reader factory callable.
///
/// This holds a Python callable that returns a RecordBatchReader when called.
/// For rescannable sources, the callable can be invoked multiple times to
/// get fresh readers.
pub struct PyScannable {
/// Python callable that returns a RecordBatchReader
reader_factory: Py<PyAny>,
schema: SchemaRef,
num_rows: Option<usize>,
rescannable: bool,
}
impl std::fmt::Debug for PyScannable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PyScannable")
.field("schema", &self.schema)
.field("num_rows", &self.num_rows)
.field("rescannable", &self.rescannable)
.finish()
}
}
impl Scannable for PyScannable {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
let reader: Result<ArrowArrayStreamReader, Error> = {
Python::attach(|py| {
let result =
self.reader_factory
.call0(py)
.map_err(|e| lancedb::Error::Runtime {
message: format!("Python reader factory failed: {}", e),
})?;
ArrowArrayStreamReader::from_pyarrow_bound(result.bind(py)).map_err(|e| {
lancedb::Error::Runtime {
message: format!("Failed to create Arrow reader from Python: {}", e),
}
})
})
};
// Reader is blocking but stream is non-blocking, so we need to spawn a task to pull.
let (tx, rx) = tokio::sync::mpsc::channel(1);
let join_handle = tokio::task::spawn_blocking(move || {
let reader = match reader {
Ok(reader) => reader,
Err(e) => {
let _ = tx.blocking_send(Err(e));
return;
}
};
for batch in reader {
match batch {
Ok(batch) => {
if tx.blocking_send(Ok(batch)).is_err() {
// Receiver dropped, stop processing
break;
}
}
Err(source) => {
let _ = tx.blocking_send(Err(Error::Arrow { source }));
break;
}
}
}
});
let schema = self.schema.clone();
let stream = futures::stream::unfold(
(rx, Some(join_handle)),
|(mut rx, join_handle)| async move {
match rx.recv().await {
Some(Ok(batch)) => Some((Ok(batch), (rx, join_handle))),
Some(Err(e)) => Some((Err(e), (rx, join_handle))),
None => {
// Channel closed. Check if the task panicked — a panic
// drops the sender without sending an error, so without
// this check we'd silently return a truncated stream.
if let Some(handle) = join_handle {
if let Err(join_err) = handle.await {
return Some((
Err(Error::Runtime {
message: format!("Reader task panicked: {}", join_err),
}),
(rx, None),
));
}
}
None
}
}
},
);
Box::pin(SimpleRecordBatchStream::new(stream.fuse(), schema))
}
fn num_rows(&self) -> Option<usize> {
self.num_rows
}
fn rescannable(&self) -> bool {
self.rescannable
}
}
impl<'py> FromPyObject<'py> for PyScannable {
fn extract_bound(ob: &pyo3::Bound<'py, PyAny>) -> pyo3::PyResult<Self> {
// Convert from Scannable dataclass.
let schema: PyArrowType<Schema> = ob.getattr("schema")?.extract()?;
let schema = Arc::new(schema.0);
let num_rows: Option<usize> = ob.getattr("num_rows")?.extract()?;
let rescannable: bool = ob.getattr("rescannable")?.extract()?;
let reader_factory: Py<PyAny> = ob.getattr("reader")?.unbind();
Ok(Self {
schema,
reader_factory,
num_rows,
rescannable,
})
}
}

View File
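`PyScannable` holds a reader *factory* rather than a reader: a zero-argument callable that yields a fresh `RecordBatchReader` each time it is invoked, which is what makes a source rescannable. The idea in plain Python (names are illustrative, not the lancedb API):

```python
def make_reader_factory(rows):
    # A factory over the underlying data, not a single one-shot iterator.
    def factory():
        # Each call returns a brand-new iterator, so exhausting one scan
        # does not consume the source for later scans.
        return iter(rows)
    return factory

factory = make_reader_factory([1, 2, 3])
first_scan = list(factory())
second_scan = list(factory())
```

A one-shot stream (e.g. a generator passed directly) could only be scanned once; wrapping construction in a callable is what lets `rescannable` be true.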

@@ -1,2 +1,2 @@
[toolchain]
channel = "1.90.0"
channel = "1.91.0"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.26.2"
version = "0.27.0-beta.3"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -25,7 +25,9 @@ datafusion-catalog.workspace = true
datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-functions.workspace = true
datafusion-physical-expr.workspace = true
datafusion-sql.workspace = true
datafusion-physical-plan.workspace = true
datafusion.workspace = true
object_store = { workspace = true }

View File

@@ -155,9 +155,7 @@ impl IntoArrowStream for SendableRecordBatchStream {
impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
fn into_arrow(self) -> Result<SendableRecordBatchStream> {
let schema = self.schema();
let stream = self.map_err(|df_err| Error::Runtime {
message: df_err.to_string(),
});
let stream = self.map_err(|df_err| df_err.into());
Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
}
}

View File

@@ -9,13 +9,6 @@
use std::sync::Arc;
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{ArrowError, SchemaRef};
use async_trait::async_trait;
use futures::stream::once;
use futures::StreamExt;
use lance_datafusion::utils::StreamingWriteSource;
use crate::arrow::{
SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream,
};
@@ -25,6 +18,12 @@ use crate::embeddings::{
};
use crate::table::{ColumnDefinition, ColumnKind, TableDefinition};
use crate::{Error, Result};
use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{ArrowError, SchemaRef};
use async_trait::async_trait;
use futures::stream::once;
use futures::StreamExt;
use lance_datafusion::utils::StreamingWriteSource;
pub trait Scannable: Send {
/// Returns the schema of the data.
@@ -228,6 +227,19 @@ impl WithEmbeddingsScannable {
let table_definition = TableDefinition::new(output_schema, column_definitions);
let output_schema = table_definition.into_rich_schema();
Self::with_schema(inner, embeddings, output_schema)
}
/// Create a WithEmbeddingsScannable with a specific output schema.
///
/// Use this when the table schema is already known (e.g. during add) to
/// avoid nullability mismatches between the embedding function's declared
/// type and the table's stored type.
pub fn with_schema(
inner: Box<dyn Scannable>,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
output_schema: SchemaRef,
) -> Result<Self> {
Ok(Self {
inner,
embeddings,
@@ -245,9 +257,11 @@ impl Scannable for WithEmbeddingsScannable {
let inner_stream = self.inner.scan_as_stream();
let embeddings = self.embeddings.clone();
let output_schema = self.output_schema.clone();
let stream_schema = output_schema.clone();
let mapped_stream = inner_stream.then(move |batch_result| {
let embeddings = embeddings.clone();
let output_schema = output_schema.clone();
async move {
let batch = batch_result?;
let result = tokio::task::spawn_blocking(move || {
@@ -257,12 +271,29 @@ impl Scannable for WithEmbeddingsScannable {
.map_err(|e| Error::Runtime {
message: format!("Task panicked during embedding computation: {}", e),
})??;
// Cast columns to match the declared output schema. The data is
// identical but field metadata (e.g. nested nullability) may
// differ between the embedding function output and the table.
let columns: Vec<ArrayRef> = result
.columns()
.iter()
.enumerate()
.map(|(i, col)| {
let target_type = output_schema.field(i).data_type();
if col.data_type() == target_type {
Ok(col.clone())
} else {
arrow_cast::cast(col, target_type).map_err(Error::from)
}
})
.collect::<Result<_>>()?;
let result = RecordBatch::try_new(output_schema, columns)?;
Ok(result)
}
});
Box::pin(SimpleRecordBatchStream {
schema: output_schema,
schema: stream_schema,
stream: mapped_stream,
})
}
@@ -303,8 +334,13 @@ pub fn scannable_with_embeddings(
}
if !embeddings.is_empty() {
return Ok(Box::new(WithEmbeddingsScannable::try_new(
inner, embeddings,
// Use the table's schema so embedding column types (including nested
// nullability) match what's stored, avoiding mismatches with the
// embedding function's declared dest_type.
return Ok(Box::new(WithEmbeddingsScannable::with_schema(
inner,
embeddings,
table_definition.schema.clone(),
)?));
}
}
@@ -312,6 +348,133 @@ pub fn scannable_with_embeddings(
Ok(inner)
}
/// A wrapper that buffers the first RecordBatch from a Scannable so we can
/// inspect it (e.g. to estimate data size) without losing it.
pub(crate) struct PeekedScannable {
inner: Box<dyn Scannable>,
peeked: Option<RecordBatch>,
/// The first item from the stream, if it was an error. Stored so we can
/// re-emit it from `scan_as_stream` instead of silently dropping it.
first_error: Option<crate::Error>,
stream: Option<SendableRecordBatchStream>,
}
impl PeekedScannable {
pub fn new(inner: Box<dyn Scannable>) -> Self {
Self {
inner,
peeked: None,
first_error: None,
stream: None,
}
}
/// Reads and buffers the first batch from the inner scannable.
/// Returns a clone of it. Subsequent calls return the same batch.
///
/// Returns `None` if the stream is empty or the first item is an error.
/// Errors are preserved and re-emitted by `scan_as_stream`.
pub async fn peek(&mut self) -> Option<RecordBatch> {
if self.peeked.is_some() {
return self.peeked.clone();
}
// Already peeked and got an error or empty stream.
if self.stream.is_some() || self.first_error.is_some() {
return None;
}
let mut stream = self.inner.scan_as_stream();
match stream.next().await {
Some(Ok(batch)) => {
self.peeked = Some(batch.clone());
self.stream = Some(stream);
Some(batch)
}
Some(Err(e)) => {
self.first_error = Some(e);
self.stream = Some(stream);
None
}
None => {
self.stream = Some(stream);
None
}
}
}
}
impl Scannable for PeekedScannable {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
fn num_rows(&self) -> Option<usize> {
self.inner.num_rows()
}
fn rescannable(&self) -> bool {
self.inner.rescannable()
}
fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
let schema = self.inner.schema();
// If peek() hit an error, prepend it so downstream sees the error.
let error_item = self.first_error.take().map(Err);
match (self.peeked.take(), self.stream.take()) {
(Some(batch), Some(rest)) => {
let prepend = futures::stream::once(std::future::ready(Ok(batch)));
Box::pin(SimpleRecordBatchStream {
schema,
stream: prepend.chain(rest),
})
}
(Some(batch), None) => Box::pin(SimpleRecordBatchStream {
schema,
stream: futures::stream::once(std::future::ready(Ok(batch))),
}),
(None, Some(rest)) => {
if let Some(err) = error_item {
let stream = futures::stream::once(std::future::ready(err));
Box::pin(SimpleRecordBatchStream { schema, stream })
} else {
rest
}
}
(None, None) => {
// peek() was never called — just delegate
self.inner.scan_as_stream()
}
}
}
}
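The peek-then-replay behavior of `PeekedScannable` can be sketched in Python: buffer the first item so it can be inspected, then prepend it when the full scan starts so no data is dropped (illustrative sketch, not lancedb code):

```python
import itertools

class PeekedIterator:
    def __init__(self, source):
        self._source = iter(source)
        self._peeked = None
        self._have_peeked = False

    def peek(self):
        # Idempotent: the first call buffers one item, later calls
        # return the same buffered item.
        if not self._have_peeked:
            self._peeked = next(self._source, None)
            self._have_peeked = True
        return self._peeked

    def scan(self):
        # Replay the buffered item ahead of the rest of the stream.
        if self._have_peeked and self._peeked is not None:
            return itertools.chain([self._peeked], self._source)
        return self._source

it = PeekedIterator([10, 20, 30])
first = it.peek()             # 10
everything = list(it.scan())  # [10, 20, 30]
```

The Rust version additionally buffers a first *error* (`first_error`) so a failed peek is re-emitted by `scan_as_stream` instead of being swallowed; this sketch only covers the happy path.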
/// Compute the number of write partitions based on data size estimates.
///
/// `sample_bytes` and `sample_rows` come from a representative batch and are
/// used to estimate per-row size. `total_rows_hint` is the total row count
/// when known; otherwise the sample's row count is used as a lower-bound
/// estimate.
///
/// Targets roughly 1 million rows or 2 GiB per partition, capped at
/// `max_partitions` (typically the number of available CPU cores).
pub(crate) fn estimate_write_partitions(
sample_bytes: usize,
sample_rows: usize,
total_rows_hint: Option<usize>,
max_partitions: usize,
) -> usize {
if sample_rows == 0 {
return 1;
}
let bytes_per_row = sample_bytes / sample_rows;
let total_rows = total_rows_hint.unwrap_or(sample_rows);
let total_bytes = total_rows * bytes_per_row;
let by_rows = total_rows.div_ceil(1_000_000);
let by_bytes = total_bytes.div_ceil(2 * 1024 * 1024 * 1024);
by_rows.max(by_bytes).max(1).min(max_partitions)
}
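The partition math is easy to mirror outside Rust. A Python sketch using the same thresholds (1 million rows or 2 GiB per partition, ceiling division, capped at `max_partitions`):

```python
def estimate_write_partitions(sample_bytes, sample_rows, total_rows_hint, max_partitions):
    # Python mirror of the Rust helper above; `total_rows_hint=None`
    # corresponds to Option::None.
    if sample_rows == 0:
        return 1
    bytes_per_row = sample_bytes // sample_rows
    total_rows = total_rows_hint if total_rows_hint is not None else sample_rows
    total_bytes = total_rows * bytes_per_row
    by_rows = -(-total_rows // 1_000_000)            # div_ceil over 1M rows
    by_bytes = -(-total_bytes // (2 * 1024 ** 3))    # div_ceil over 2 GiB
    return min(max(by_rows, by_bytes, 1), max_partitions)
```

This reproduces the cases in the Rust tests below: tiny data gets one partition, 2.5M rows get three, and a 10M-row estimate is capped by `max_partitions`.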
#[cfg(test)]
mod tests {
use super::*;
@@ -408,6 +571,231 @@ mod tests {
assert!(result2.unwrap().is_err());
}
mod peeked_scannable_tests {
use crate::test_utils::TestCustomError;
use super::*;
#[tokio::test]
async fn test_peek_returns_first_batch() {
let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
let first = peeked.peek().await.unwrap();
assert_eq!(first, batch);
}
#[tokio::test]
async fn test_peek_is_idempotent() {
let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
let first = peeked.peek().await.unwrap();
let second = peeked.peek().await.unwrap();
assert_eq!(first, second);
}
#[tokio::test]
async fn test_scan_after_peek_returns_all_data() {
let batches = vec![
record_batch!(("id", Int64, [1, 2])).unwrap(),
record_batch!(("id", Int64, [3, 4, 5])).unwrap(),
];
let mut peeked = PeekedScannable::new(Box::new(batches.clone()));
let first = peeked.peek().await.unwrap();
assert_eq!(first, batches[0]);
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0], batches[0]);
assert_eq!(result[1], batches[1]);
}
#[tokio::test]
async fn test_scan_without_peek_passes_through() {
let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0], batch);
}
#[tokio::test]
async fn test_delegates_num_rows() {
let batches = vec![
record_batch!(("id", Int64, [1, 2])).unwrap(),
record_batch!(("id", Int64, [3])).unwrap(),
];
let peeked = PeekedScannable::new(Box::new(batches));
assert_eq!(peeked.num_rows(), Some(3));
}
#[tokio::test]
async fn test_non_rescannable_stream_data_preserved() {
let batches = vec![
record_batch!(("id", Int64, [1, 2])).unwrap(),
record_batch!(("id", Int64, [3])).unwrap(),
];
let schema = batches[0].schema();
let inner = futures::stream::iter(batches.clone().into_iter().map(Ok));
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
schema,
stream: inner,
});
let mut peeked = PeekedScannable::new(Box::new(stream));
assert!(!peeked.rescannable());
assert_eq!(peeked.num_rows(), None);
let first = peeked.peek().await.unwrap();
assert_eq!(first, batches[0]);
// All data is still available via scan_as_stream
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0], batches[0]);
assert_eq!(result[1], batches[1]);
}
#[tokio::test]
async fn test_error_in_first_batch_propagates() {
let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
"id",
arrow_schema::DataType::Int64,
false,
)]));
let inner = futures::stream::iter(vec![Err(Error::External {
source: Box::new(TestCustomError),
})]);
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
schema,
stream: inner,
});
let mut peeked = PeekedScannable::new(Box::new(stream));
// peek returns None for errors
assert!(peeked.peek().await.is_none());
// But the error should come through when scanning
let mut stream = peeked.scan_as_stream();
let first = stream.next().await.unwrap();
assert!(first.is_err());
let err = first.unwrap_err();
assert!(
matches!(&err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
"Expected TestCustomError to be preserved, got: {err}"
);
}
#[tokio::test]
async fn test_error_in_later_batch_propagates() {
let good_batch = record_batch!(("id", Int64, [1, 2])).unwrap();
let schema = good_batch.schema();
let inner = futures::stream::iter(vec![
Ok(good_batch.clone()),
Err(Error::External {
source: Box::new(TestCustomError),
}),
]);
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
schema,
stream: inner,
});
let mut peeked = PeekedScannable::new(Box::new(stream));
// peek succeeds with the first batch
let first = peeked.peek().await.unwrap();
assert_eq!(first, good_batch);
// scan_as_stream should yield the first batch, then the error
let mut stream = peeked.scan_as_stream();
let batch1 = stream.next().await.unwrap().unwrap();
assert_eq!(batch1, good_batch);
let batch2 = stream.next().await.unwrap();
assert!(batch2.is_err());
let err = batch2.unwrap_err();
assert!(
matches!(&err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
"Expected TestCustomError to be preserved, got: {err}"
);
}
#[tokio::test]
async fn test_empty_stream_returns_none() {
let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
"id",
arrow_schema::DataType::Int64,
false,
)]));
let inner = futures::stream::empty();
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
schema,
stream: inner,
});
let mut peeked = PeekedScannable::new(Box::new(stream));
assert!(peeked.peek().await.is_none());
// Scanning an empty (post-peek) stream should yield nothing
let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
assert!(result.is_empty());
}
}
mod estimate_write_partitions_tests {
use super::*;
#[test]
fn test_small_data_single_partition() {
// 100 rows * 24 bytes/row = 2400 bytes — well under both thresholds
assert_eq!(estimate_write_partitions(2400, 100, Some(100), 8), 1);
}
#[test]
fn test_scales_by_row_count() {
// 2.5M rows at 24 bytes/row — row threshold dominates
// ceil(2_500_000 / 1_000_000) = 3
assert_eq!(estimate_write_partitions(72, 3, Some(2_500_000), 8), 3);
}
#[test]
fn test_scales_by_byte_size() {
// 100k rows at 40KB/row = ~4GB total → ceil(4GB / 2GB) = 2
let sample_bytes = 40_000 * 10;
assert_eq!(
estimate_write_partitions(sample_bytes, 10, Some(100_000), 8),
2
);
}
#[test]
fn test_capped_at_max_partitions() {
// 10M rows would want 10 partitions, but capped at 4
assert_eq!(estimate_write_partitions(72, 3, Some(10_000_000), 4), 4);
}
#[test]
fn test_zero_sample_rows_returns_one() {
assert_eq!(estimate_write_partitions(0, 0, Some(1_000_000), 8), 1);
}
#[test]
fn test_no_row_hint_uses_sample_size() {
// Without a hint, uses sample_rows (3), which is small
assert_eq!(estimate_write_partitions(72, 3, None, 8), 1);
}
#[test]
fn test_always_at_least_one() {
assert_eq!(estimate_write_partitions(24, 1, Some(1), 8), 1);
}
}
mod embedding_tests {
use super::*;
use crate::embeddings::MemoryRegistry;

View File

@@ -85,8 +85,10 @@ pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableRequest) -> OpenTableReq
/// Describes what happens when creating a table and a table with
/// the same name already exists
#[derive(Default)]
pub enum CreateTableMode {
/// If the table already exists, an error is returned
#[default]
Create,
/// If the table already exists, it is opened. Any provided data is
/// ignored. The function will be passed an OpenTableBuilder to customize
@@ -104,12 +106,6 @@ impl CreateTableMode {
}
}
impl Default for CreateTableMode {
fn default() -> Self {
Self::Create
}
}
/// A request to create a table
pub struct CreateTableRequest {
/// The name of the new table

View File

@@ -7,6 +7,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use lance_io::object_store::{ObjectStoreParams, StorageOptionsAccessor};
use lance_namespace::{
models::{
CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
@@ -212,45 +213,75 @@ impl Database for LanceNamespaceDatabase {
..Default::default()
};
let location = match self.namespace.declare_table(declare_request).await {
Ok(response) => response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from declare_table response".to_string(),
})?,
Err(e) => {
// Check if the error is "not supported" and try create_empty_table as fallback
let err_str = e.to_string().to_lowercase();
if err_str.contains("not supported") || err_str.contains("not implemented") {
warn!(
"declare_table is not supported by the namespace client, \
let (location, initial_storage_options) =
match self.namespace.declare_table(declare_request).await {
Ok(response) => {
let loc = response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from declare_table response"
.to_string(),
})?;
// Use storage options from response, fall back to self.storage_options
let opts = response
.storage_options
.or_else(|| Some(self.storage_options.clone()))
.filter(|o| !o.is_empty());
(loc, opts)
}
Err(e) => {
// Check if the error is "not supported" and try create_empty_table as fallback
let err_str = e.to_string().to_lowercase();
if err_str.contains("not supported") || err_str.contains("not implemented") {
warn!(
"declare_table is not supported by the namespace client, \
falling back to deprecated create_empty_table. \
create_empty_table is deprecated and will be removed in Lance 3.0.0. \
Please upgrade your namespace client to support declare_table."
);
#[allow(deprecated)]
let create_empty_request = CreateEmptyTableRequest {
id: Some(table_id.clone()),
..Default::default()
};
);
#[allow(deprecated)]
let create_empty_request = CreateEmptyTableRequest {
id: Some(table_id.clone()),
..Default::default()
};
#[allow(deprecated)]
let create_response = self
.namespace
.create_empty_table(create_empty_request)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to create empty table: {}", e),
#[allow(deprecated)]
let create_response = self
.namespace
.create_empty_table(create_empty_request)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to create empty table: {}", e),
})?;
let loc = create_response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from create_empty_table response"
.to_string(),
})?;
create_response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from create_empty_table response"
.to_string(),
})?
} else {
return Err(Error::Runtime {
message: format!("Failed to declare table: {}", e),
});
// For deprecated path, use self.storage_options
let opts = if self.storage_options.is_empty() {
None
} else {
Some(self.storage_options.clone())
};
(loc, opts)
} else {
return Err(Error::Runtime {
message: format!("Failed to declare table: {}", e),
});
}
}
}
};
let write_params = if let Some(storage_opts) = initial_storage_options {
let mut params = request.write_options.lance_write_params.unwrap_or_default();
let store_params = params
.store_params
.get_or_insert_with(ObjectStoreParams::default);
store_params.storage_options_accessor = Some(Arc::new(
StorageOptionsAccessor::with_static_options(storage_opts),
));
Some(params)
} else {
request.write_options.lance_write_params
};
let native_table = NativeTable::create_from_namespace(
@@ -260,7 +291,7 @@ impl Database for LanceNamespaceDatabase {
request.namespace.clone(),
request.data,
None, // write_store_wrapper not used for namespace connections
request.write_options.lance_write_params,
write_params,
self.read_consistency_interval,
self.server_side_query_enabled,
self.session.clone(),

View File
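The `declare_table` path above follows a common compatibility pattern: try the new API, and only when the server rejects it as unsupported, warn and fall back to the deprecated call. A minimal Python sketch of that control flow (client and method names are illustrative, not the `lance_namespace` API):

```python
import warnings

def declare_or_create(client, table_id):
    # Prefer the new call; fall back to the deprecated one only on
    # "not supported" / "not implemented" errors, re-raising everything else.
    try:
        return client.declare_table(table_id)
    except Exception as e:
        msg = str(e).lower()
        if "not supported" in msg or "not implemented" in msg:
            warnings.warn(
                "declare_table is not supported; falling back to the "
                "deprecated create_empty_table",
                DeprecationWarning,
            )
            return client.create_empty_table(table_id)
        raise

class FakeClient:
    # Simulates an older server that only knows the deprecated call.
    def declare_table(self, table_id):
        raise RuntimeError("declare_table is not supported")

    def create_empty_table(self, table_id):
        return f"memory://{table_id}"

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    location = declare_or_create(FakeClient(), "tbl")
```

Matching on the error message string is brittle but mirrors the Rust code, which also lowercases and substring-matches because the namespace client does not expose a typed "unsupported" error.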

@@ -57,7 +57,7 @@ pub struct PermutationConfig {
}
/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub enum ShuffleStrategy {
/// The data is randomly shuffled
///
@@ -78,15 +78,10 @@ pub enum ShuffleStrategy {
/// The data is not shuffled
///
/// This is useful for debugging and testing.
#[default]
None,
}
impl Default for ShuffleStrategy {
fn default() -> Self {
Self::None
}
}
/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. This

View File

@@ -426,6 +426,7 @@ impl PermutationReader {
row_ids_query = row_ids_query.limit(limit as usize);
}
let mut row_ids = row_ids_query.execute().await?;
let mut idx_offset = 0;
while let Some(batch) = row_ids.try_next().await? {
let row_ids = batch
.column(0)
@@ -433,8 +434,9 @@ impl PermutationReader {
.values()
.to_vec();
for (i, row_id) in row_ids.iter().enumerate() {
offset_map.insert(i as u64, *row_id);
offset_map.insert(i as u64 + idx_offset, *row_id);
}
idx_offset += batch.num_rows() as u64;
}
let offset_map = Arc::new(offset_map);
*offset_map_ref = Some(offset_map.clone());
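The `idx_offset` fix matters because the old loop keyed the map by the index *within* each batch, so every batch after the first overwrote the entries of its predecessors. With a running offset, later batches land at distinct keys. A Python sketch of the corrected loop:

```python
def build_offset_map(batches):
    # batches is a list of row-id lists, one per record batch.
    offset_map = {}
    idx_offset = 0
    for row_ids in batches:
        for i, row_id in enumerate(row_ids):
            # Key by the global offset, not the within-batch index.
            offset_map[i + idx_offset] = row_id
        idx_offset += len(row_ids)
    return offset_map

m = build_offset_map([[100, 101], [102]])
```

Without the offset, the second batch's row id 102 would have clobbered key 0, which is exactly what the new `test_filtered_permutation_full_iteration` test below guards against.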
@@ -845,4 +847,106 @@ mod tests {
.to_vec();
assert_eq!(idx_values, vec![row_ids[2] as i32]);
}
#[tokio::test]
async fn test_filtered_permutation_full_iteration() {
use crate::dataloader::permutation::builder::PermutationBuilder;
// Create a base table with 10000 rows where idx goes 0..10000.
// Filter to even values only, giving 5000 rows in the permutation.
let base_table = lance_datagen::gen_batch()
.col("idx", lance_datagen::array::step::<Int32Type>())
.into_mem_table("tbl", RowCount::from(10000), BatchCount::from(1))
.await;
let permutation_table = PermutationBuilder::new(base_table.clone())
.with_filter("idx % 2 = 0".to_string())
.build()
.await
.unwrap();
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 5000);
let reader = PermutationReader::try_from_tables(
base_table.base_table().clone(),
permutation_table.base_table().clone(),
0,
)
.await
.unwrap();
assert_eq!(reader.count_rows(), 5000);
// Iterate through all batches using a batch size that doesn't evenly divide
// the row count (5000 / 128 = 39 full batches + 1 batch of 8 rows).
let batch_size = 128;
let mut stream = reader
.read(
Select::All,
QueryExecutionOptions {
max_batch_length: batch_size,
..Default::default()
},
)
.await
.unwrap();
let mut total_rows = 0u64;
let mut all_idx_values = Vec::new();
while let Some(batch) = stream.try_next().await.unwrap() {
assert!(batch.num_rows() <= batch_size as usize);
total_rows += batch.num_rows() as u64;
let idx_col = batch.column(0).as_primitive::<Int32Type>().values();
all_idx_values.extend(idx_col.iter().copied());
}
assert_eq!(total_rows, 5000);
assert_eq!(all_idx_values.len(), 5000);
// Every value should be even (from the filter)
assert!(all_idx_values.iter().all(|v| v % 2 == 0));
// Should have 5000 unique values
let unique: std::collections::HashSet<i32> = all_idx_values.iter().copied().collect();
assert_eq!(unique.len(), 5000);
// Use take_offsets to fetch rows from the beginning, middle, and end
// of the permutation. The values should match what we saw during iteration.
// Beginning
let batch = reader.take_offsets(&[0, 1, 2], Select::All).await.unwrap();
assert_eq!(batch.num_rows(), 3);
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
assert_eq!(idx_values, &all_idx_values[0..3]);
// Middle
let batch = reader
.take_offsets(&[2499, 2500, 2501], Select::All)
.await
.unwrap();
assert_eq!(batch.num_rows(), 3);
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
assert_eq!(idx_values, &all_idx_values[2499..2502]);
// End (last 3 rows)
let batch = reader
.take_offsets(&[4997, 4998, 4999], Select::All)
.await
.unwrap();
assert_eq!(batch.num_rows(), 3);
let idx_values = batch
.column(0)
.as_primitive::<Int32Type>()
.values()
.to_vec();
assert_eq!(idx_values, &all_idx_values[4997..5000]);
}
}

View File

@@ -27,9 +27,10 @@ use crate::{
pub const SPLIT_ID_COLUMN: &str = "split_id";
/// Strategy for assigning rows to splits
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub enum SplitStrategy {
/// All rows will have split id 0
#[default]
NoSplit,
/// Rows will be randomly assigned to splits
///
@@ -73,15 +74,6 @@ pub enum SplitStrategy {
Calculated { calculation: String },
}
// The default is not to split the data
//
// All data will be assigned to a single split.
impl Default for SplitStrategy {
fn default() -> Self {
Self::NoSplit
}
}
impl SplitStrategy {
pub fn validate(&self, num_rows: u64) -> Result<()> {
match self {

View File

@@ -4,6 +4,7 @@
use std::sync::PoisonError;
use arrow_schema::ArrowError;
use datafusion_common::DataFusionError;
use snafu::Snafu;
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
@@ -96,28 +97,74 @@ pub type Result<T> = std::result::Result<T, Error>;
impl From<ArrowError> for Error {
fn from(source: ArrowError) -> Self {
match source {
ArrowError::ExternalError(source) => match source.downcast::<Self>() {
Ok(e) => *e,
Err(source) => Self::External { source },
},
ArrowError::ExternalError(source) => Self::from_box_error(source),
_ => Self::Arrow { source },
}
}
}
impl From<DataFusionError> for Error {
fn from(source: DataFusionError) -> Self {
match source {
DataFusionError::ArrowError(source, _) => (*source).into(),
DataFusionError::External(source) => Self::from_box_error(source),
other => Self::External {
source: Box::new(other),
},
}
}
}
impl From<lance::Error> for Error {
fn from(source: lance::Error) -> Self {
// Try to unwrap external errors that were wrapped by lance
match source {
lance::Error::Wrapped { error, .. } => match error.downcast::<Self>() {
Ok(e) => *e,
Err(source) => Self::External { source },
},
lance::Error::Wrapped { error, .. } => Self::from_box_error(error),
lance::Error::External { source } => Self::from_box_error(source),
_ => Self::Lance { source },
}
}
}
impl Error {
fn from_box_error(mut source: Box<dyn std::error::Error + Send + Sync>) -> Self {
source = match source.downcast::<Self>() {
Ok(e) => match *e {
Self::External { source } => return Self::from_box_error(source),
other => return other,
},
Err(source) => source,
};
source = match source.downcast::<lance::Error>() {
Ok(e) => match *e {
lance::Error::Wrapped { error, .. } => return Self::from_box_error(error),
other => return other.into(),
},
Err(source) => source,
};
source = match source.downcast::<ArrowError>() {
Ok(e) => match *e {
ArrowError::ExternalError(source) => return Self::from_box_error(source),
other => return other.into(),
},
Err(source) => source,
};
source = match source.downcast::<DataFusionError>() {
Ok(e) => match *e {
DataFusionError::ArrowError(source, _) => return (*source).into(),
DataFusionError::External(source) => return Self::from_box_error(source),
other => return other.into(),
},
Err(source) => source,
};
Self::External { source }
}
}
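The `from_box_error` helper above repeatedly downcasts a boxed error through the known wrapper types until it reaches the original error. A minimal std-only sketch of the same unwrap-by-downcast pattern; the `Wrapped` type here is hypothetical, standing in for wrappers like `lance::Error::Wrapped`:

```rust
use std::error::Error;

// Hypothetical wrapper type, standing in for lance::Error::Wrapped.
#[derive(Debug)]
struct Wrapped(Box<dyn Error + Send + Sync>);

impl std::fmt::Display for Wrapped {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "wrapped: {}", self.0)
    }
}
impl Error for Wrapped {}

#[derive(Debug)]
struct Original(&'static str);

impl std::fmt::Display for Original {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
impl Error for Original {}

// Peel known wrapper layers by downcasting, like Error::from_box_error.
fn unwrap_error(mut source: Box<dyn Error + Send + Sync>) -> Box<dyn Error + Send + Sync> {
    loop {
        source = match source.downcast::<Wrapped>() {
            Ok(w) => (*w).0,            // unwrap one layer and keep going
            Err(other) => return other, // not a known wrapper: done
        };
    }
}

fn main() {
    let nested: Box<dyn Error + Send + Sync> =
        Box::new(Wrapped(Box::new(Wrapped(Box::new(Original("boom"))))));
    let inner = unwrap_error(nested);
    assert_eq!(inner.to_string(), "boom");
    println!("{}", inner);
}
```

The real implementation repeats this for each foreign wrapper (`lance::Error`, `ArrowError`, `DataFusionError`) so the original message survives however many times it was re-boxed.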
impl From<object_store::Error> for Error {
fn from(source: object_store::Error) -> Self {
Self::ObjectStore { source }

rust/lancedb/src/expr.rs (new file, 131 lines)

@@ -0,0 +1,131 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Expression builder API for type-safe query construction
//!
//! This module provides a fluent API for building expressions that can be used
//! in filters and projections. It wraps DataFusion's expression system.
//!
//! # Examples
//!
//! ```rust
//! use std::ops::Mul;
//! use lancedb::expr::{col, lit};
//!
//! let expr = col("age").gt(lit(18));
//! let expr = col("age").gt(lit(18)).and(col("status").eq(lit("active")));
//! let expr = col("price") * lit(1.1);
//! ```
mod sql;
pub use sql::expr_to_sql_string;
use std::sync::Arc;
use arrow_schema::DataType;
use datafusion_expr::{expr_fn::cast, Expr, ScalarUDF};
use datafusion_functions::string::expr_fn as string_expr_fn;
pub use datafusion_expr::{col, lit};
pub use datafusion_expr::Expr as DfExpr;
pub fn lower(expr: Expr) -> Expr {
string_expr_fn::lower(expr)
}
pub fn upper(expr: Expr) -> Expr {
string_expr_fn::upper(expr)
}
pub fn contains(expr: Expr, search: Expr) -> Expr {
string_expr_fn::contains(expr, search)
}
pub fn expr_cast(expr: Expr, data_type: DataType) -> Expr {
cast(expr, data_type)
}
lazy_static::lazy_static! {
static ref FUNC_REGISTRY: std::sync::RwLock<std::collections::HashMap<String, Arc<ScalarUDF>>> = {
let mut m = std::collections::HashMap::new();
m.insert("lower".to_string(), datafusion_functions::string::lower());
m.insert("upper".to_string(), datafusion_functions::string::upper());
m.insert("contains".to_string(), datafusion_functions::string::contains());
m.insert("btrim".to_string(), datafusion_functions::string::btrim());
m.insert("ltrim".to_string(), datafusion_functions::string::ltrim());
m.insert("rtrim".to_string(), datafusion_functions::string::rtrim());
m.insert("concat".to_string(), datafusion_functions::string::concat());
m.insert("octet_length".to_string(), datafusion_functions::string::octet_length());
std::sync::RwLock::new(m)
};
}
pub fn func(name: impl AsRef<str>, args: Vec<Expr>) -> crate::Result<Expr> {
let name = name.as_ref();
let registry = FUNC_REGISTRY
.read()
.map_err(|e| crate::Error::InvalidInput {
message: format!("lock poisoned: {}", e),
})?;
let udf = registry
.get(name)
.ok_or_else(|| crate::Error::InvalidInput {
message: format!("unknown function: {}", name),
})?;
Ok(Expr::ScalarFunction(
datafusion_expr::expr::ScalarFunction::new_udf(udf.clone(), args),
))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_col_lit_comparisons() {
let expr = col("age").gt(lit(18));
let sql = expr_to_sql_string(&expr).unwrap();
assert!(sql.contains("age") && sql.contains("18"));
let expr = col("name").eq(lit("Alice"));
let sql = expr_to_sql_string(&expr).unwrap();
assert!(sql.contains("name") && sql.contains("Alice"));
}
#[test]
fn test_compound_expression() {
let expr = col("age").gt(lit(18)).and(col("status").eq(lit("active")));
let sql = expr_to_sql_string(&expr).unwrap();
assert!(sql.contains("age") && sql.contains("status"));
}
#[test]
fn test_string_functions() {
let expr = lower(col("name"));
let sql = expr_to_sql_string(&expr).unwrap();
assert!(sql.to_lowercase().contains("lower"));
let expr = contains(col("text"), lit("search"));
let sql = expr_to_sql_string(&expr).unwrap();
assert!(sql.to_lowercase().contains("contains"));
}
#[test]
fn test_func() {
let expr = func("lower", vec![col("x")]).unwrap();
let sql = expr_to_sql_string(&expr).unwrap();
assert!(sql.to_lowercase().contains("lower"));
let result = func("unknown_func", vec![col("x")]);
assert!(result.is_err());
}
#[test]
fn test_arithmetic() {
let expr = col("price") * lit(1.1);
let sql = expr_to_sql_string(&expr).unwrap();
assert!(sql.contains("price"));
}
}
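The registry pattern above (a process-wide `RwLock<HashMap>` consulted by `func`) can be sketched without the DataFusion types; here the UDFs are replaced by plain string closures, so the names and signatures are illustrative only:

```rust
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};

type StrFn = Arc<dyn Fn(&str) -> String + Send + Sync>;

// Process-wide registry, analogous to FUNC_REGISTRY in expr.rs.
fn registry() -> &'static RwLock<HashMap<String, StrFn>> {
    static REG: OnceLock<RwLock<HashMap<String, StrFn>>> = OnceLock::new();
    REG.get_or_init(|| {
        let mut m: HashMap<String, StrFn> = HashMap::new();
        m.insert("lower".into(), Arc::new(|s| s.to_lowercase()));
        m.insert("upper".into(), Arc::new(|s| s.to_uppercase()));
        RwLock::new(m)
    })
}

// Look up a function by name, erroring on unknown names like `func` does.
fn call(name: &str, arg: &str) -> Result<String, String> {
    let reg = registry().read().map_err(|e| format!("lock poisoned: {e}"))?;
    let f = reg.get(name).ok_or_else(|| format!("unknown function: {name}"))?;
    Ok(f(arg))
}

fn main() {
    assert_eq!(call("lower", "ABC").unwrap(), "abc");
    assert!(call("unknown_func", "x").is_err());
    println!("ok");
}
```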


@@ -0,0 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use datafusion_expr::Expr;
use datafusion_sql::unparser;
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
let ast = unparser::expr_to_sql(expr).map_err(|e| crate::Error::InvalidInput {
message: format!("failed to serialize expression to SQL: {}", e),
})?;
Ok(ast.to_string())
}


@@ -195,6 +195,11 @@ mod test {
table::WriteOptions,
};
// This test is ignored because lance 3.0 introduced LocalWriter optimization
// that bypasses the object store wrapper for local writes. The mirroring feature
// still works for remote/cloud storage, but can't be tested with local storage.
// See lance commit c878af433 "perf: create local writer for efficient local writes"
#[ignore]
#[tokio::test]
async fn test_e2e() {
let dir1 = tempfile::tempdir().unwrap().keep().canonicalize().unwrap();
@@ -250,32 +255,38 @@ mod test {
let primary_location = dir1.join("test.lance").canonicalize().unwrap();
let secondary_location = dir2.join(primary_location.strip_prefix("/").unwrap());
let mut primary_iter = WalkDir::new(&primary_location).into_iter();
let mut secondary_iter = WalkDir::new(&secondary_location).into_iter();
// Skip lance internal directories (_versions, _transactions) and manifest files
let should_skip = |path: &std::path::Path| -> bool {
let path_str = path.to_str().unwrap();
path_str.contains("_latest.manifest")
|| path_str.contains("_versions")
|| path_str.contains("_transactions")
};
let mut primary_elem = primary_iter.next();
let mut secondary_elem = secondary_iter.next();
let primary_files: Vec<_> = WalkDir::new(&primary_location)
.into_iter()
.filter_entry(|e| !should_skip(e.path()))
.filter_map(|e| e.ok())
.map(|e| {
e.path()
.strip_prefix(&primary_location)
.unwrap()
.to_path_buf()
})
.collect();
loop {
if primary_elem.is_none() && secondary_elem.is_none() {
break;
}
// primary has more data than secondary; it should not run out before secondary
let primary_f = primary_elem.unwrap().unwrap();
// Hit a manifest file; skip it. _versions contains all the manifests and should not exist on the secondary.
let primary_raw_path = primary_f.file_name().to_str().unwrap();
if primary_raw_path.contains("_latest.manifest") {
primary_elem = primary_iter.next();
continue;
}
let secondary_f = secondary_elem.unwrap().unwrap();
assert_eq!(
primary_f.path().strip_prefix(&primary_location),
secondary_f.path().strip_prefix(&secondary_location)
);
let secondary_files: Vec<_> = WalkDir::new(&secondary_location)
.into_iter()
.filter_entry(|e| !should_skip(e.path()))
.filter_map(|e| e.ok())
.map(|e| {
e.path()
.strip_prefix(&secondary_location)
.unwrap()
.to_path_buf()
})
.collect();
primary_elem = primary_iter.next();
secondary_elem = secondary_iter.next();
}
assert_eq!(primary_files, secondary_files, "File lists should match");
}
}


@@ -169,6 +169,7 @@ pub mod database;
pub mod dataloader;
pub mod embeddings;
pub mod error;
pub mod expr;
pub mod index;
pub mod io;
pub mod ipc;
@@ -192,13 +193,14 @@ pub use error::{Error, Result};
use lance_linalg::distance::DistanceType as LanceDistanceType;
pub use table::Table;
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize, Default)]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum DistanceType {
/// Euclidean distance. This is a very common distance metric that
/// accounts for both magnitude and direction when determining the distance
/// between vectors. l2 distance has a range of [0, ∞).
#[default]
L2,
/// Cosine distance. Cosine distance is a distance metric
/// calculated from the cosine similarity between two vectors. Cosine
@@ -220,12 +222,6 @@ pub enum DistanceType {
Hamming,
}
impl Default for DistanceType {
fn default() -> Self {
Self::L2
}
}
impl From<DistanceType> for LanceDistanceType {
fn from(value: DistanceType) -> Self {
match value {


@@ -359,6 +359,28 @@ pub trait QueryBase {
/// on the filter column(s).
fn only_if(self, filter: impl AsRef<str>) -> Self;
/// Only return rows which match the filter, using an expression builder.
///
/// Use [`crate::expr`] for building type-safe expressions:
///
/// ```
/// use lancedb::expr::{col, lit};
/// use lancedb::query::{QueryBase, ExecutableQuery};
///
/// # use lancedb::Table;
/// # async fn query(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
/// let results = table.query()
/// .only_if_expr(col("age").gt(lit(18)).and(col("status").eq(lit("active"))))
/// .execute()
/// .await?;
/// # Ok(())
/// # }
/// ```
///
/// Note: Expression filters are not supported for remote/server-side queries.
/// Use [`QueryBase::only_if`] with SQL strings for remote tables.
fn only_if_expr(self, filter: datafusion_expr::Expr) -> Self;
/// Perform a full text search on the table.
///
/// The results will be returned in order of BM25 scores.
@@ -468,6 +490,11 @@ impl<T: HasQuery> QueryBase for T {
self
}
fn only_if_expr(mut self, filter: datafusion_expr::Expr) -> Self {
self.mut_query().filter = Some(QueryFilter::Datafusion(filter));
self
}
fn full_text_search(mut self, query: FullTextSearchQuery) -> Self {
if self.mut_query().limit.is_none() {
self.mut_query().limit = Some(DEFAULT_TOP_K);


@@ -724,12 +724,58 @@ pub mod test_utils {
}
}
/// Consume a reqwest body into bytes, returning an error if the body
/// stream fails. This is used by MockSender to materialize streaming
/// bodies so that data pipeline errors (e.g. NaN rejection) are triggered
/// during mock sends just as they would be during a real HTTP upload.
pub async fn try_collect_body(body: reqwest::Body) -> std::result::Result<Vec<u8>, String> {
use http_body::Body;
use std::pin::Pin;
let mut body = body;
let mut data = Vec::new();
let mut body_pin = Pin::new(&mut body);
while let Some(frame) = futures::StreamExt::next(&mut futures::stream::poll_fn(|cx| {
body_pin.as_mut().poll_frame(cx)
}))
.await
{
match frame {
Ok(frame) => {
if let Some(bytes) = frame.data_ref() {
data.extend_from_slice(bytes);
}
}
Err(e) => return Err(e.to_string()),
}
}
Ok(data)
}
impl HttpSend for MockSender {
async fn send(
&self,
_client: &reqwest::Client,
request: reqwest::Request,
mut request: reqwest::Request,
) -> reqwest::Result<reqwest::Response> {
// Consume any streaming body to materialize it into bytes.
// This triggers data pipeline errors (e.g. NaN rejection) that
// would otherwise only fire when a real HTTP client reads the body.
if let Some(body) = request.body_mut().take() {
match try_collect_body(body).await {
Ok(bytes) => {
*request.body_mut() = Some(reqwest::Body::from(bytes));
}
Err(msg) => {
// Simulate a failed request by returning a 500 response.
return Ok(http::Response::builder()
.status(500)
.body(msg)
.unwrap()
.into());
}
}
}
let response = (self.f)(request);
Ok(response)
}


@@ -60,6 +60,34 @@ impl<'a> RetryCounter<'a> {
self.check_out_of_retries(Box::new(source), status_code)
}
/// Increment the appropriate failure counter based on the error type.
///
/// For `Error::Http` whose source is a connect error, increments
/// `connect_failures`. For read errors (`is_body` or `is_decode`),
/// increments `read_failures`. For all other errors, increments
/// `request_failures`. Calls `check_out_of_retries` to enforce global limits.
pub fn increment_from_error(&mut self, source: crate::Error) -> crate::Result<()> {
let reqwest_err = match &source {
crate::Error::Http { source, .. } => source.downcast_ref::<reqwest::Error>(),
_ => None,
};
if reqwest_err.is_some_and(|e| e.is_connect()) {
self.connect_failures += 1;
} else if reqwest_err.is_some_and(|e| e.is_body() || e.is_decode()) {
self.read_failures += 1;
} else {
self.request_failures += 1;
}
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
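The bucketing described in the doc comment can be sketched std-only; `FailureKind` here is a hypothetical stand-in for the reqwest predicates (`is_connect`, `is_body`, `is_decode`):

```rust
// Hypothetical stand-in for reqwest's error predicates.
#[derive(Clone, Copy)]
enum FailureKind {
    Connect,
    Body,
    Decode,
    Other,
}

#[derive(Default, Debug, PartialEq)]
struct Counters {
    connect: u32,
    read: u32,
    request: u32,
}

impl Counters {
    // Mirror increment_from_error's bucketing: connect errors first,
    // then read errors (body/decode), then everything else.
    fn increment(&mut self, kind: FailureKind) {
        match kind {
            FailureKind::Connect => self.connect += 1,
            FailureKind::Body | FailureKind::Decode => self.read += 1,
            FailureKind::Other => self.request += 1,
        }
    }
}

fn main() {
    let mut c = Counters::default();
    c.increment(FailureKind::Connect);
    c.increment(FailureKind::Decode);
    c.increment(FailureKind::Other);
    assert_eq!(c, Counters { connect: 1, read: 1, request: 1 });
    println!("{:?}", c);
}
```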
pub fn increment_connect_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
self.connect_failures += 1;
let status_code = source.status();
@@ -77,7 +105,7 @@ impl<'a> RetryCounter<'a> {
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
"Retrying request {:?} ({}/{} connect, {}/{} request, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
@@ -91,6 +119,115 @@ impl<'a> RetryCounter<'a> {
}
}
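The sleep computation above adds a random jitter term to a deterministic backoff before sleeping. A hedged std-only sketch; the exponential form `factor * 2^attempt` is an assumption for illustration, since the hunk only shows `backoff + jitter`:

```rust
use std::time::Duration;

// Assumed backoff shape: factor * 2^attempt, plus jitter in [0, jitter_max).
// The exponential form is an illustration; the diff only shows backoff + jitter.
fn sleep_time(attempt: u32, backoff_factor: f32, jitter: f32) -> Duration {
    let backoff = backoff_factor * 2f32.powi(attempt as i32);
    Duration::from_secs_f32(backoff + jitter)
}

fn main() {
    // With factor 0.25 and no jitter, attempts 0..3 back off 0.25s, 0.5s, 1.0s.
    assert_eq!(sleep_time(0, 0.25, 0.0), Duration::from_secs_f32(0.25));
    assert_eq!(sleep_time(2, 0.25, 0.0), Duration::from_secs_f32(1.0));
    println!("{:?}", sleep_time(2, 0.25, 0.0));
}
```

In the real code the jitter is `rand::random::<f32>() * backoff_jitter`, which is why the tests above pin `backoff_jitter` to `0.0` to stay deterministic.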
#[cfg(test)]
mod tests {
use super::*;
fn test_config() -> ResolvedRetryConfig {
ResolvedRetryConfig {
retries: 3,
connect_retries: 2,
read_retries: 3,
backoff_factor: 0.0,
backoff_jitter: 0.0,
statuses: vec![reqwest::StatusCode::BAD_GATEWAY],
}
}
/// Get a real reqwest connect error by trying to connect to a refused port.
async fn make_connect_error() -> reqwest::Error {
// Port 1 is almost always refused/unavailable.
reqwest::Client::new()
.get("http://127.0.0.1:1")
.send()
.await
.unwrap_err()
}
#[tokio::test]
async fn test_increment_from_error_connect() {
let config = test_config();
let mut counter = RetryCounter::new(&config, "test".to_string());
let connect_err = make_connect_error().await;
assert!(connect_err.is_connect());
let http_err = crate::Error::Http {
source: Box::new(connect_err),
request_id: "test".to_string(),
status_code: None,
};
// First connect failure: should be ok (1 < 2)
counter.increment_from_error(http_err).unwrap();
assert_eq!(counter.connect_failures, 1);
assert_eq!(counter.request_failures, 0);
// Second connect failure: should hit the limit (2 >= 2)
let connect_err2 = make_connect_error().await;
let http_err2 = crate::Error::Http {
source: Box::new(connect_err2),
request_id: "test".to_string(),
status_code: None,
};
let result = counter.increment_from_error(http_err2);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
crate::Error::Retry {
connect_failures: 2,
max_connect_failures: 2,
..
}
));
}
#[test]
fn test_increment_from_error_request() {
let config = test_config();
let mut counter = RetryCounter::new(&config, "test".to_string());
let http_err = crate::Error::Http {
source: "bad gateway".into(),
request_id: "test".to_string(),
status_code: Some(reqwest::StatusCode::BAD_GATEWAY),
};
counter.increment_from_error(http_err).unwrap();
assert_eq!(counter.request_failures, 1);
assert_eq!(counter.connect_failures, 0);
}
#[tokio::test]
async fn test_increment_from_error_respects_global_limits() {
// If request_failures is already at max, a connect error should still
// trigger the global limit check.
let config = test_config();
let mut counter = RetryCounter::new(&config, "test".to_string());
counter.request_failures = 3; // at max
let connect_err = make_connect_error().await;
let http_err = crate::Error::Http {
source: Box::new(connect_err),
request_id: "test".to_string(),
status_code: None,
};
// Even though connect_failures would be 1 (under limit of 2),
// request_failures is already at 3 (>= limit of 3), so this should fail.
let result = counter.increment_from_error(http_err);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
crate::Error::Retry {
request_failures: 3,
connect_failures: 1,
..
}
));
}
}
#[derive(Debug, Clone)]
pub struct ResolvedRetryConfig {
pub retries: u8,

File diff suppressed because it is too large.


@@ -8,7 +8,6 @@ use std::sync::{Arc, Mutex};
use arrow_array::{ArrayRef, RecordBatch, UInt64Array};
use arrow_ipc::CompressionType;
use arrow_schema::ArrowError;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::EquivalenceProperties;
@@ -76,7 +75,15 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
self.add_result.lock().unwrap().clone()
}
fn stream_as_body(data: SendableRecordBatchStream) -> DataFusionResult<reqwest::Body> {
/// Stream the input into an HTTP body as an Arrow IPC stream, capturing any
/// stream errors into the provided channel. Errors from the input plan
/// (e.g. NaN rejection) would otherwise be swallowed inside the HTTP body
/// upload; by stashing them in the channel we can surface them with their
/// original message after the request completes.
fn stream_as_http_body(
data: SendableRecordBatchStream,
error_tx: tokio::sync::oneshot::Sender<DataFusionError>,
) -> DataFusionResult<reqwest::Body> {
let options = arrow_ipc::writer::IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;
let writer = arrow_ipc::writer::StreamWriter::try_new_with_options(
@@ -85,26 +92,44 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
options,
)?;
let stream = futures::stream::try_unfold((data, writer), move |(mut data, mut writer)| {
async move {
let stream = futures::stream::try_unfold(
(data, writer, Some(error_tx), false),
move |(mut data, mut writer, error_tx, finished)| async move {
if finished {
return Ok(None);
}
match data.next().await {
Some(Ok(batch)) => {
writer.write(&batch)?;
writer
.write(&batch)
.map_err(|e| std::io::Error::other(e.to_string()))?;
let buffer = std::mem::take(writer.get_mut());
Ok(Some((buffer, (data, writer))))
Ok(Some((buffer, (data, writer, error_tx, false))))
}
Some(Err(e)) => {
// Send the original error through the channel before
// returning a generic error to reqwest.
if let Some(tx) = error_tx {
let _ = tx.send(e);
}
Err(std::io::Error::other(
"input stream error (see error channel)",
))
}
Some(Err(e)) => Err(e),
None => {
if let Err(ArrowError::IpcError(_msg)) = writer.finish() {
// Will error if already closed.
return Ok(None);
};
writer
.finish()
.map_err(|e| std::io::Error::other(e.to_string()))?;
let buffer = std::mem::take(writer.get_mut());
Ok(Some((buffer, (data, writer))))
if buffer.is_empty() {
Ok(None)
} else {
Ok(Some((buffer, (data, writer, None, true))))
}
}
}
}
});
},
);
Ok(reqwest::Body::wrap_stream(stream))
}
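The stash-then-surface pattern used by `stream_as_http_body` (send the original error through a side channel, hand a generic error to the transport, then check the channel after the request) can be sketched with a std `mpsc` channel standing in for the tokio oneshot; the function names here are illustrative:

```rust
use std::sync::mpsc;

// Simulated pipeline stage: on failure, stash the original error in the
// side channel and return only a generic error to the "transport".
fn produce(fail: bool, error_tx: mpsc::Sender<String>) -> Result<Vec<u8>, String> {
    if fail {
        let _ = error_tx.send("NaN rejected in input stream".to_string());
        return Err("input stream error (see error channel)".to_string());
    }
    Ok(vec![1, 2, 3])
}

fn upload(fail: bool) -> Result<Vec<u8>, String> {
    let (tx, rx) = mpsc::channel();
    let result = produce(fail, tx);
    // Prefer the original error over the generic transport error.
    if let Ok(original) = rx.try_recv() {
        return Err(original);
    }
    result
}

fn main() {
    assert_eq!(upload(false).unwrap(), vec![1, 2, 3]);
    assert_eq!(upload(true).unwrap_err(), "NaN rejected in input stream");
    println!("ok");
}
```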
@@ -202,24 +227,41 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
request = request.query(&[("mode", "overwrite")]);
}
let body = Self::stream_as_body(input_stream)?;
let (error_tx, mut error_rx) = tokio::sync::oneshot::channel();
let body = Self::stream_as_http_body(input_stream, error_tx)?;
let request = request.body(body);
let (request_id, response) = client
.send(request)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let response =
RemoteTable::<Sender>::handle_table_not_found(&table_name, response, &request_id)
let result: DataFusionResult<(String, _)> = async {
let (request_id, response) = client
.send(request)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let response = client
.check_response(&request_id, response)
let response = RemoteTable::<Sender>::handle_table_not_found(
&table_name,
response,
&request_id,
)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let response = client
.check_response(&request_id, response)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
Ok((request_id, response))
}
.await;
// If the request failed due to an input stream error, surface the
// original error (e.g. NaN rejection) instead of the HTTP error.
if let Ok(stream_err) = error_rx.try_recv() {
return Err(stream_err);
}
let (request_id, response) = result?;
let body_text = response.text().await.map_err(|e| {
DataFusionError::External(Box::new(Error::Http {
source: Box::new(e),

File diff suppressed because it is too large.


@@ -3,13 +3,19 @@
use std::sync::Arc;
use arrow_schema::{DataType, Fields, Schema};
use lance::dataset::WriteMode;
use serde::{Deserialize, Serialize};
use crate::data::scannable::scannable_with_embeddings;
use crate::data::scannable::Scannable;
use crate::embeddings::EmbeddingRegistry;
use crate::Result;
use crate::table::datafusion::cast::cast_to_table_schema;
use crate::table::datafusion::reject_nan::reject_nan_vectors;
use crate::table::datafusion::scannable_exec::ScannableExec;
use crate::{Error, Result};
use super::{BaseTable, WriteOptions};
use super::{BaseTable, TableDefinition, WriteOptions};
#[derive(Debug, Clone, Default)]
pub enum AddDataMode {
@@ -29,12 +35,22 @@ pub struct AddResult {
pub version: u64,
}
#[derive(Debug, Default, Clone, Copy)]
pub enum NaNVectorBehavior {
/// Reject any vectors containing NaN values (the default)
#[default]
Error,
/// Allow NaN values to be added, but they will not be indexed for search
Keep,
}
/// A builder for configuring a [`crate::table::Table::add`] operation
pub struct AddDataBuilder {
pub(crate) parent: Arc<dyn BaseTable>,
pub(crate) data: Box<dyn Scannable>,
pub(crate) mode: AddDataMode,
pub(crate) write_options: WriteOptions,
pub(crate) on_nan_vectors: NaNVectorBehavior,
pub(crate) embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}
@@ -59,6 +75,7 @@ impl AddDataBuilder {
data,
mode: AddDataMode::Append,
write_options: WriteOptions::default(),
on_nan_vectors: NaNVectorBehavior::default(),
embedding_registry,
}
}
@@ -73,16 +90,123 @@ impl AddDataBuilder {
self
}
/// Configure how to handle NaN values in vector columns.
///
/// By default, any vectors containing NaN values will be rejected with an
/// error, since NaNs cannot be indexed for search. Setting this to `Keep`
/// will allow NaN values to be added to the table, but they will not be
/// indexed and will not be searchable.
pub fn on_nan_vectors(mut self, behavior: NaNVectorBehavior) -> Self {
self.on_nan_vectors = behavior;
self
}
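The behavior configured above can be pictured with a std-only NaN check over flattened vector values; the real check runs inside a DataFusion plan node (`reject_nan_vectors`), so this helper is purely illustrative:

```rust
// Illustrative only: mimic the Error / Keep choice on raw f32 values.
#[derive(Clone, Copy)]
enum NaNVectorBehavior {
    Error, // reject vectors containing NaN (the default)
    Keep,  // pass them through; they will not be indexed
}

fn check_vector(values: &[f32], behavior: NaNVectorBehavior) -> Result<(), String> {
    match behavior {
        NaNVectorBehavior::Keep => Ok(()),
        NaNVectorBehavior::Error => {
            if values.iter().any(|v| v.is_nan()) {
                Err("vector contains NaN values".to_string())
            } else {
                Ok(())
            }
        }
    }
}

fn main() {
    let v = [0.1_f32, 0.2, f32::NAN, 0.4];
    assert!(check_vector(&v, NaNVectorBehavior::Error).is_err());
    assert!(check_vector(&v, NaNVectorBehavior::Keep).is_ok());
    println!("ok");
}
```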
pub async fn execute(self) -> Result<AddResult> {
self.parent.clone().add(self).await
}
/// Build a DataFusion execution plan that applies embeddings, casts data to
/// the table schema, and optionally rejects NaN vectors.
///
/// Returns the plan along with whether the input is rescannable (for retry
/// decisions) and whether this is an overwrite operation.
pub(crate) fn into_plan(
mut self,
table_schema: &Schema,
table_def: &TableDefinition,
) -> Result<PreprocessingOutput> {
let overwrite = self
.write_options
.lance_write_params
.as_ref()
.is_some_and(|p| matches!(p.mode, WriteMode::Overwrite))
|| matches!(self.mode, AddDataMode::Overwrite);
if !overwrite {
validate_schema(&self.data.schema(), table_schema)?;
}
self.data =
scannable_with_embeddings(self.data, table_def, self.embedding_registry.as_ref())?;
let rescannable = self.data.rescannable();
let plan: Arc<dyn datafusion_physical_plan::ExecutionPlan> =
Arc::new(ScannableExec::new(self.data));
// Skip casting when overwriting — the input schema replaces the table schema.
let plan = if overwrite {
plan
} else {
cast_to_table_schema(plan, table_schema)?
};
let plan = match self.on_nan_vectors {
NaNVectorBehavior::Error => reject_nan_vectors(plan)?,
NaNVectorBehavior::Keep => plan,
};
Ok(PreprocessingOutput {
plan,
overwrite,
rescannable,
write_options: self.write_options,
mode: self.mode,
})
}
}
pub struct PreprocessingOutput {
pub plan: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
#[cfg_attr(not(feature = "remote"), allow(dead_code))]
pub overwrite: bool,
#[cfg_attr(not(feature = "remote"), allow(dead_code))]
pub rescannable: bool,
pub write_options: WriteOptions,
pub mode: AddDataMode,
}
/// Check that the input schema is valid for insert.
///
/// Fields can be in different orders, so match by name.
///
/// If a column exists in the input but not in the table, that is an error (no extra columns are allowed).
///
/// If a column exists in table but not in input, that is okay - it may be filled with nulls.
///
/// If the types are not exactly the same, we will attempt to cast later - so that is also okay at this stage.
///
/// If the nullability is different, that is also okay - we can relax nullability when casting.
fn validate_schema(input: &Schema, table: &Schema) -> Result<()> {
validate_fields(input.fields(), table.fields())
}
fn validate_fields(input: &Fields, table: &Fields) -> Result<()> {
for field in input {
match table.iter().find(|f| f.name() == field.name()) {
None => {
return Err(Error::InvalidInput {
message: format!("field '{}' does not exist in table schema", field.name()),
});
}
Some(table_field) => {
if let (DataType::Struct(in_children), DataType::Struct(tbl_children)) =
(field.data_type(), table_field.data_type())
{
validate_fields(in_children, tbl_children)?;
}
}
}
}
Ok(())
}
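The matching rules described above (extra input columns are an error, missing columns are fine, type differences are deferred to casting) can be sketched over plain field names:

```rust
// Illustrative: match input fields to table fields by name, like
// validate_fields. Extra input columns error; missing ones are fine.
fn validate_by_name(input: &[&str], table: &[&str]) -> Result<(), String> {
    for field in input {
        if !table.contains(field) {
            return Err(format!("field '{field}' does not exist in table schema"));
        }
    }
    Ok(())
}

fn main() {
    assert!(validate_by_name(&["id"], &["id", "text"]).is_ok()); // subset ok
    assert!(validate_by_name(&["id", "extra"], &["id", "text"]).is_err());
    println!("ok");
}
```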
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{record_batch, RecordBatch, RecordBatchIterator};
use arrow::datatypes::Float64Type;
use arrow_array::{
record_batch, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray, ListArray,
RecordBatch, RecordBatchIterator,
};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use futures::TryStreamExt;
use lance::dataset::{WriteMode, WriteParams};
@@ -94,8 +218,10 @@ mod tests {
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry,
};
use crate::query::{ExecutableQuery, QueryBase, Select};
use crate::table::add_data::NaNVectorBehavior;
use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions};
use crate::test_utils::embeddings::MockEmbed;
use crate::test_utils::TestCustomError;
use crate::Error;
use super::AddDataMode;
@@ -160,17 +286,20 @@ mod tests {
test_add_with_data(stream).await;
}
#[derive(Debug)]
struct MyError;
impl std::fmt::Display for MyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MyError occurred")
}
fn assert_preserves_external_error(err: &Error) {
assert!(
matches!(err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
"Expected Error::External, got: {err:?}"
);
// The original TestCustomError message should be preserved through the
// error chain, even if the error gets wrapped multiple times by
// lance's insert pipeline.
assert!(
err.to_string().contains("TestCustomError occurred"),
"Expected original error message to be preserved, got: {err}"
);
}
impl std::error::Error for MyError {}
#[tokio::test]
async fn test_add_preserves_reader_error() {
let table = create_test_table().await;
@@ -178,7 +307,7 @@ mod tests {
let schema = first_batch.schema();
let iterator = vec![
Ok(first_batch),
Err(ArrowError::ExternalError(Box::new(MyError))),
Err(ArrowError::ExternalError(Box::new(TestCustomError))),
];
let reader: Box<dyn arrow_array::RecordBatchReader + Send> = Box::new(
RecordBatchIterator::new(iterator.into_iter(), schema.clone()),
@@ -186,7 +315,7 @@ mod tests {
let result = table.add(reader).execute().await;
assert!(result.is_err());
assert_preserves_external_error(&result.unwrap_err());
}
#[tokio::test]
@@ -197,7 +326,7 @@ mod tests {
let iterator = vec![
Ok(first_batch),
Err(Error::External {
source: Box::new(MyError),
source: Box::new(TestCustomError),
}),
];
let stream = futures::stream::iter(iterator);
@@ -208,7 +337,7 @@ mod tests {
let result = table.add(stream).execute().await;
assert!(result.is_err());
assert_preserves_external_error(&result.unwrap_err());
}
#[tokio::test]
@@ -340,4 +469,248 @@ mod tests {
assert_eq!(embedding_col.null_count(), 0);
}
}
#[tokio::test]
async fn test_add_casts_to_table_schema() {
let table_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("text", DataType::Utf8, false),
Field::new(
"embedding",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
),
]));
let input_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false), // Upcast integer
Field::new("text", DataType::LargeUtf8, false), // Re-encode string
// Cast list of float64 to fixed-size list of float32
// (This will only work if the list size is correct. See the next test.)
Field::new(
"embedding",
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
false,
),
]));
let db = connect("memory://").execute().await.unwrap();
let table = db
.create_empty_table("cast_test", table_schema.clone())
.execute()
.await
.unwrap();
let batch = RecordBatch::try_new(
input_schema,
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(LargeStringArray::from(vec!["hello", "world"])),
Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
Some(vec![0.1, 0.2, 0.3, 0.4].into_iter().map(Some)),
Some(vec![0.5, 0.6, 0.7, 0.8].into_iter().map(Some)),
])),
],
)
.unwrap();
table.add(batch).execute().await.unwrap();
let row_count = table.count_rows(None).await.unwrap();
assert_eq!(row_count, 2);
}
#[tokio::test]
async fn test_add_rejects_bad_vector_dimensions() {
let table_schema = Arc::new(Schema::new(vec![Field::new(
"embedding",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
)]));
let input_schema = Arc::new(Schema::new(vec![Field::new(
"embedding",
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
false,
)]));
let db = connect("memory://").execute().await.unwrap();
let table = db
.create_empty_table("cast_test", table_schema.clone())
.execute()
.await
.unwrap();
let batch = RecordBatch::try_new(
input_schema,
vec![Arc::new(
ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
Some(vec![0.1, 0.2, 0.3, 0.4].into_iter().map(Some)),
Some(vec![0.5, 0.6, 0.8].into_iter().map(Some)),
]),
)],
)
.unwrap();
let res = table.add(batch).execute().await;
// TODO: to recover the error, we will need fix upstream in Lance.
// assert!(
// matches!(res, Err(Error::Arrow { source: ArrowError::CastError(_) })),
// "Expected schema mismatch error due to wrong vector dimensions, but got: {res:?}"
// );
assert!(
res.is_err(),
"Expected error due to wrong vector dimensions, but got success"
);
}
#[tokio::test]
async fn test_add_rejects_nan_vectors() {
let schema = Arc::new(Schema::new(vec![Field::new(
"embedding",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
)]));
let db = connect("memory://").execute().await.unwrap();
let table = db
.create_empty_table("nan_test", schema.clone())
.execute()
.await
.unwrap();
let batch = RecordBatch::try_new(
schema,
vec![Arc::new(
FixedSizeListArray::try_new(
Arc::new(Field::new("item", DataType::Float32, true)),
4,
Arc::new(Float32Array::from(vec![0.1, 0.2, f32::NAN, 0.4])),
None,
)
.unwrap(),
)],
)
.unwrap();
let res = table.add(batch.clone()).execute().await;
let err = res.unwrap_err();
assert!(
err.to_string().contains("NaN"),
"Expected error mentioning NaN values, but got: {err:?}"
);
table
.add(batch)
.on_nan_vectors(NaNVectorBehavior::Keep)
.execute()
.await
.unwrap();
let row_count = table.count_rows(None).await.unwrap();
assert_eq!(row_count, 1);
}
#[tokio::test]
async fn test_add_subschema() {
let data = record_batch!(("id", Int64, [4, 5]), ("text", Utf8, ["foo", "bar"])).unwrap();
let db = connect("memory://").execute().await.unwrap();
let table = db
.create_table("test", data.clone())
.execute()
.await
.unwrap();
let new_data = record_batch!(("id", Int64, [6, 7])).unwrap();
table.add(new_data).execute().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 4);
assert_eq!(
table
.count_rows(Some("id IS NOT NULL".to_string()))
.await
.unwrap(),
4
);
assert_eq!(
table
.count_rows(Some("text IS NOT NULL".to_string()))
.await
.unwrap(),
2
);
// We can still cast
let new_data = record_batch!(("text", LargeUtf8, ["baz", "qux"])).unwrap();
table.add(new_data).execute().await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 6);
assert_eq!(
table
.count_rows(Some("id IS NOT NULL".to_string()))
.await
.unwrap(),
4
);
assert_eq!(
table
.count_rows(Some("text IS NOT NULL".to_string()))
.await
.unwrap(),
4
);
// Extra columns mean an error
let new_data =
record_batch!(("id", Int64, [8, 9]), ("extra", Utf8, ["extra1", "extra2"])).unwrap();
let res = table.add(new_data).execute().await;
assert!(
res.is_err(),
"Expected error due to extra column, but got: {res:?}"
);
// Insert with a subset of struct sub-fields
let struct_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new(
"metadata",
DataType::Struct(
vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Utf8, true),
]
.into(),
),
true,
),
]));
let db2 = connect("memory://").execute().await.unwrap();
let table2 = db2
.create_empty_table("struct_test", struct_schema)
.execute()
.await
.unwrap();
// Insert with only the "a" sub-field of the struct
let sub_struct_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new(
"metadata",
DataType::Struct(vec![Field::new("a", DataType::Int64, true)].into()),
true,
),
]));
let struct_batch = RecordBatch::try_new(
sub_struct_schema,
vec![
Arc::new(arrow_array::Int64Array::from(vec![1, 2])),
Arc::new(arrow_array::StructArray::from(vec![(
Arc::new(Field::new("a", DataType::Int64, true)),
Arc::new(arrow_array::Int64Array::from(vec![10, 20]))
as Arc<dyn arrow_array::Array>,
)])),
],
)
.unwrap();
table2.add(struct_batch).execute().await.unwrap();
assert_eq!(table2.count_rows(None).await.unwrap(), 2);
}
}

View File

@@ -3,7 +3,10 @@
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
pub mod cast;
pub mod insert;
pub mod reject_nan;
pub mod scannable_exec;
pub mod udtf;
use std::{collections::HashMap, sync::Arc};

View File

@@ -0,0 +1,498 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
use datafusion::functions::core::{get_field, named_struct};
use datafusion_common::config::ConfigOptions;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::{cast, Literal};
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_plan::expressions::Column;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
use crate::{Error, Result};
pub fn cast_to_table_schema(
input: Arc<dyn ExecutionPlan>,
table_schema: &Schema,
) -> Result<Arc<dyn ExecutionPlan>> {
let input_schema = input.schema();
if input_schema.fields() == table_schema.fields() {
return Ok(input);
}
let exprs = build_field_exprs(
input_schema.fields(),
table_schema.fields(),
&|idx| Arc::new(Column::new(input_schema.field(idx).name(), idx)) as Arc<dyn PhysicalExpr>,
&input_schema,
)?;
let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = exprs
.into_iter()
.map(|(expr, field)| (expr, field.name().clone()))
.collect();
let projection = ProjectionExec::try_new(exprs, input).map_err(crate::Error::from)?;
Ok(Arc::new(projection))
}
/// Build expressions to project input fields to match the table schema.
///
/// For each table field that exists in the input, produce an expression that
/// reads from the input and casts if needed. Fields in the table but not in the
/// input are omitted (the storage layer handles missing columns).
fn build_field_exprs(
input_fields: &Fields,
table_fields: &Fields,
get_input_expr: &dyn Fn(usize) -> Arc<dyn PhysicalExpr>,
input_schema: &Schema,
) -> Result<Vec<(Arc<dyn PhysicalExpr>, FieldRef)>> {
let config = Arc::new(ConfigOptions::default());
let mut result = Vec::new();
for table_field in table_fields {
let Some(input_idx) = input_fields
.iter()
.position(|f| f.name() == table_field.name())
else {
continue;
};
let input_field = &input_fields[input_idx];
let input_expr = get_input_expr(input_idx);
let expr = match (input_field.data_type(), table_field.data_type()) {
// Both are structs: recurse into sub-fields to handle subschemas and casts.
(DataType::Struct(in_children), DataType::Struct(tbl_children))
if in_children != tbl_children =>
{
let sub_exprs = build_field_exprs(
in_children,
tbl_children,
&|child_idx| {
let child_name = in_children[child_idx].name();
Arc::new(ScalarFunctionExpr::new(
&format!("get_field({child_name})"),
get_field(),
vec![
input_expr.clone(),
Arc::new(Literal::new(ScalarValue::from(child_name.as_str()))),
],
Arc::new(in_children[child_idx].as_ref().clone()),
config.clone(),
)) as Arc<dyn PhysicalExpr>
},
input_schema,
)?;
let output_struct_fields: Fields = sub_exprs
.iter()
.map(|(_, f)| f.clone())
.collect::<Vec<_>>()
.into();
let output_field: FieldRef = Arc::new(Field::new(
table_field.name(),
DataType::Struct(output_struct_fields),
table_field.is_nullable(),
));
// Build named_struct(lit("a"), expr_a, lit("b"), expr_b, ...)
let mut ns_args: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
for (sub_expr, sub_field) in &sub_exprs {
ns_args.push(Arc::new(Literal::new(ScalarValue::from(
sub_field.name().as_str(),
))));
ns_args.push(sub_expr.clone());
}
let ns_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
&format!("named_struct({})", table_field.name()),
named_struct(),
ns_args,
output_field.clone(),
config.clone(),
));
result.push((ns_expr, output_field));
continue;
}
// Types match: pass through.
(inp, tbl) if inp == tbl => input_expr,
// Types differ: cast.
_ => cast(input_expr, input_schema, table_field.data_type().clone()).map_err(|e| {
Error::InvalidInput {
message: format!(
"cannot cast field '{}' from {} to {}: {}",
table_field.name(),
input_field.data_type(),
table_field.data_type(),
e
),
}
})?,
};
let output_field = Arc::new(Field::new(
table_field.name(),
table_field.data_type().clone(),
table_field.is_nullable(),
));
result.push((expr, output_field));
}
Ok(result)
}
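The field-matching rule above can be sketched without DataFusion: walk the table's fields in order, keep only those present in the input (missing ones are filled in by the storage layer), and flag any type mismatch for a cast. `plan_fields` and its tuple-based "schema" are hypothetical stand-ins for illustration, not part of this module.

```rust
// Sketch of the matching rule in `build_field_exprs`, using (name, type)
// string pairs in place of Arrow schemas. Output follows table order, drops
// extra input fields, skips table fields missing from the input, and flags
// fields whose types differ so they would need a cast.
fn plan_fields<'a>(
    input: &[(&'a str, &'a str)],
    table: &[(&'a str, &'a str)],
) -> Vec<(&'a str, bool)> {
    table
        .iter()
        .filter_map(|(name, tbl_ty)| {
            input
                .iter()
                .find(|(in_name, _)| in_name == name)
                .map(|(_, in_ty)| (*name, in_ty != tbl_ty))
        })
        .collect()
}

fn main() {
    let input = [("b", "Utf8"), ("a", "Int32"), ("extra", "Utf8")];
    let table = [("a", "Int64"), ("b", "Utf8"), ("c", "Utf8")];
    // "extra" is dropped, "c" is skipped, and "a" needs an Int32 -> Int64 cast.
    assert_eq!(plan_fields(&input, &table), vec![("a", true), ("b", false)]);
    println!("ok");
}
```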
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{
Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray,
};
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use futures::TryStreamExt;
use super::cast_to_table_schema;
async fn plan_from_batch(
batch: RecordBatch,
) -> Arc<dyn datafusion_physical_plan::ExecutionPlan> {
let schema = batch.schema();
let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
let ctx = SessionContext::new();
ctx.register_table("t", Arc::new(table)).unwrap();
let df = ctx.table("t").await.unwrap();
df.create_physical_plan().await.unwrap()
}
async fn collect(plan: Arc<dyn datafusion_physical_plan::ExecutionPlan>) -> RecordBatch {
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
arrow_select::concat::concat_batches(&plan.schema(), &batches).unwrap()
}
#[tokio::test]
async fn test_noop_when_schemas_match() {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StringArray::from(vec!["x", "y"])),
],
)
.unwrap();
let input = plan_from_batch(batch).await;
let input_ptr = Arc::as_ptr(&input);
let result = cast_to_table_schema(input, &schema).unwrap();
assert_eq!(Arc::as_ptr(&result), input_ptr);
}
#[tokio::test]
async fn test_simple_type_cast() {
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("val", DataType::Float32, false),
])),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])),
],
)
.unwrap();
let table_schema = Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("val", DataType::Float64, false),
]);
let plan = plan_from_batch(input_batch).await;
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
let result = collect(casted).await;
assert_eq!(result.schema().field(0).data_type(), &DataType::Int64);
assert_eq!(result.schema().field(1).data_type(), &DataType::Float64);
let ids: &Int64Array = result.column(0).as_any().downcast_ref().unwrap();
assert_eq!(ids.values(), &[1, 2, 3]);
let vals: &Float64Array = result.column(1).as_any().downcast_ref().unwrap();
assert!((vals.value(0) - 1.5).abs() < 1e-6);
assert!((vals.value(1) - 2.5).abs() < 1e-6);
assert!((vals.value(2) - 3.5).abs() < 1e-6);
}
#[tokio::test]
async fn test_missing_table_field_skipped() {
// Input has "a", table expects "a" and "b". "b" is omitted from the
// projection since the storage layer fills in missing columns.
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![10, 20]))],
)
.unwrap();
let table_schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, true),
]);
let plan = plan_from_batch(input_batch).await;
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
let result = collect(casted).await;
assert_eq!(result.num_columns(), 1);
assert_eq!(result.schema().field(0).name(), "a");
}
#[tokio::test]
async fn test_extra_input_fields_dropped() {
// Input has "a" and "extra"; table only expects "a".
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("extra", DataType::Utf8, false),
])),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StringArray::from(vec!["x", "y"])),
],
)
.unwrap();
let table_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
let plan = plan_from_batch(input_batch).await;
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
let result = collect(casted).await;
assert_eq!(result.num_columns(), 1);
assert_eq!(result.schema().field(0).name(), "a");
assert_eq!(result.schema().field(0).data_type(), &DataType::Int64);
}
#[tokio::test]
async fn test_reorders_to_table_schema() {
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("b", DataType::Utf8, false),
Field::new("a", DataType::Int32, false),
])),
vec![
Arc::new(StringArray::from(vec!["x", "y"])),
Arc::new(Int32Array::from(vec![1, 2])),
],
)
.unwrap();
let table_schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
]);
let plan = plan_from_batch(input_batch).await;
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
let result = collect(casted).await;
assert_eq!(result.schema().field(0).name(), "a");
assert_eq!(result.schema().field(1).name(), "b");
let a: &Int32Array = result.column(0).as_any().downcast_ref().unwrap();
assert_eq!(a.values(), &[1, 2]);
let b: &StringArray = result.column(1).as_any().downcast_ref().unwrap();
assert_eq!(b.value(0), "x");
}
#[tokio::test]
async fn test_struct_subfield_cast() {
// Input struct has {x: Int32, y: Int32}, table expects {x: Int64, y: Int64}.
let inner_fields = vec![
Field::new("x", DataType::Int32, false),
Field::new("y", DataType::Int32, false),
];
let struct_array = StructArray::from(vec![
(
Arc::new(inner_fields[0].clone()),
Arc::new(Int32Array::from(vec![1, 2])) as _,
),
(
Arc::new(inner_fields[1].clone()),
Arc::new(Int32Array::from(vec![3, 4])) as _,
),
]);
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new(
"s",
DataType::Struct(inner_fields.into()),
false,
)])),
vec![Arc::new(struct_array)],
)
.unwrap();
let table_inner = vec![
Field::new("x", DataType::Int64, false),
Field::new("y", DataType::Int64, false),
];
let table_schema = Schema::new(vec![Field::new(
"s",
DataType::Struct(table_inner.into()),
false,
)]);
let plan = plan_from_batch(input_batch).await;
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
let result = collect(casted).await;
let struct_col = result
.column(0)
.as_any()
.downcast_ref::<StructArray>()
.unwrap();
assert_eq!(struct_col.column(0).data_type(), &DataType::Int64);
assert_eq!(struct_col.column(1).data_type(), &DataType::Int64);
let x: &Int64Array = struct_col.column(0).as_any().downcast_ref().unwrap();
assert_eq!(x.values(), &[1, 2]);
let y: &Int64Array = struct_col.column(1).as_any().downcast_ref().unwrap();
assert_eq!(y.values(), &[3, 4]);
}
#[tokio::test]
async fn test_struct_subschema() {
// Input struct has {x, y, z}, table only expects {x, z}.
let inner_fields = vec![
Field::new("x", DataType::Int32, false),
Field::new("y", DataType::Int32, false),
Field::new("z", DataType::Int32, false),
];
let struct_array = StructArray::from(vec![
(
Arc::new(inner_fields[0].clone()),
Arc::new(Int32Array::from(vec![1, 2])) as _,
),
(
Arc::new(inner_fields[1].clone()),
Arc::new(Int32Array::from(vec![10, 20])) as _,
),
(
Arc::new(inner_fields[2].clone()),
Arc::new(Int32Array::from(vec![100, 200])) as _,
),
]);
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new(
"s",
DataType::Struct(inner_fields.into()),
false,
)])),
vec![Arc::new(struct_array)],
)
.unwrap();
let table_inner = vec![
Field::new("x", DataType::Int32, false),
Field::new("z", DataType::Int32, false),
];
let table_schema = Schema::new(vec![Field::new(
"s",
DataType::Struct(table_inner.into()),
false,
)]);
let plan = plan_from_batch(input_batch).await;
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
let result = collect(casted).await;
let struct_col = result
.column(0)
.as_any()
.downcast_ref::<StructArray>()
.unwrap();
assert_eq!(struct_col.num_columns(), 2);
let x: &Int32Array = struct_col
.column_by_name("x")
.unwrap()
.as_any()
.downcast_ref()
.unwrap();
assert_eq!(x.values(), &[1, 2]);
let z: &Int32Array = struct_col
.column_by_name("z")
.unwrap()
.as_any()
.downcast_ref()
.unwrap();
assert_eq!(z.values(), &[100, 200]);
}
#[tokio::test]
async fn test_incompatible_cast_errors() {
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Binary, false)])),
vec![Arc::new(arrow_array::BinaryArray::from_vec(vec![b"hi"]))],
)
.unwrap();
let table_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let plan = plan_from_batch(input_batch).await;
let result = cast_to_table_schema(plan, &table_schema);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("cannot cast field 'a'"),
"unexpected error: {err_msg}"
);
}
#[tokio::test]
async fn test_mixed_cast_and_passthrough() {
// "a" needs cast (Int32→Int64), "b" passes through unchanged.
let input_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
])),
vec![
Arc::new(Int32Array::from(vec![7, 8])),
Arc::new(StringArray::from(vec!["hello", "world"])),
],
)
.unwrap();
let table_schema = Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Utf8, false),
]);
let plan = plan_from_batch(input_batch).await;
let casted = cast_to_table_schema(plan, &table_schema).unwrap();
let result = collect(casted).await;
assert_eq!(result.schema().field(0).data_type(), &DataType::Int64);
assert_eq!(result.schema().field(1).data_type(), &DataType::Utf8);
let a: &Int64Array = result.column(0).as_any().downcast_ref().unwrap();
assert_eq!(a.values(), &[7, 8]);
let b: &StringArray = result.column(1).as_any().downcast_ref().unwrap();
assert_eq!(b.value(0), "hello");
assert_eq!(b.value(1), "world");
}
}

View File

@@ -200,7 +200,7 @@ impl ExecutionPlan for InsertExec {
let new_dataset = CommitBuilder::new(dataset.clone())
.execute(merged_txn)
.await?;
ds_wrapper.set_latest(new_dataset).await;
ds_wrapper.update(new_dataset);
}
}

View File

@@ -0,0 +1,269 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! A DataFusion projection that rejects vectors containing NaN values.
use std::any::Any;
use std::sync::{Arc, LazyLock};
use arrow_array::{Array, FixedSizeListArray};
use arrow_schema::{DataType, Field, FieldRef};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_plan::expressions::Column;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
use crate::{Error, Result};
static REJECT_NAN_UDF: LazyLock<Arc<datafusion_expr::ScalarUDF>> =
LazyLock::new(|| Arc::new(datafusion_expr::ScalarUDF::from(RejectNanUdf::new())));
/// Returns true if the field is a vector column: FixedSizeList<Float16/32/64>.
fn is_vector_field(field: &Field) -> bool {
if let DataType::FixedSizeList(child, _) = field.data_type() {
matches!(
child.data_type(),
DataType::Float16 | DataType::Float32 | DataType::Float64
)
} else {
false
}
}
/// Wraps the input plan with a projection that checks vector columns for NaN values.
///
/// Non-vector columns pass through unchanged. Vector columns are wrapped with a
/// UDF that returns the column as-is if no NaNs are present, or errors otherwise.
pub fn reject_nan_vectors(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
let schema = input.schema();
let config = Arc::new(ConfigOptions::default());
let udf = REJECT_NAN_UDF.clone();
let mut has_vector_cols = false;
let mut exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::new();
for (idx, field) in schema.fields().iter().enumerate() {
let col_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new(field.name(), idx));
if is_vector_field(field) {
has_vector_cols = true;
let wrapped: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
&format!("reject_nan({})", field.name()),
udf.clone(),
vec![col_expr],
Arc::clone(field) as FieldRef,
config.clone(),
));
exprs.push((wrapped, field.name().clone()));
} else {
exprs.push((col_expr, field.name().clone()));
}
}
if !has_vector_cols {
return Ok(input);
}
let projection = ProjectionExec::try_new(exprs, input).map_err(Error::from)?;
Ok(Arc::new(projection))
}
/// A scalar UDF that passes through FixedSizeList arrays unchanged, but errors
/// if any float values in the list are NaN.
#[derive(Debug, Hash, PartialEq, Eq)]
struct RejectNanUdf {
signature: Signature,
}
impl RejectNanUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl ScalarUDFImpl for RejectNanUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"reject_nan"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(arg_types[0].clone())
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let arg = &args.args[0];
match arg {
ColumnarValue::Array(array) => {
check_no_nans(array)?;
Ok(ColumnarValue::Array(array.clone()))
}
ColumnarValue::Scalar(_) => Ok(arg.clone()),
}
}
}
fn check_no_nans(array: &dyn Array) -> datafusion_common::Result<()> {
let fsl = array
.as_any()
.downcast_ref::<FixedSizeListArray>()
.ok_or_else(|| {
datafusion_common::DataFusionError::Internal(
"reject_nan expected FixedSizeList".to_string(),
)
})?;
// Only inspect elements that are both in a valid parent row and non-null
// themselves. Values backing null parent rows or null child elements may
// contain garbage (including NaN) per the Arrow spec.
let has_nan = (0..fsl.len()).filter(|i| fsl.is_valid(*i)).any(|i| {
let row = fsl.value(i);
match row.data_type() {
DataType::Float16 => row
.as_any()
.downcast_ref::<arrow_array::Float16Array>()
.unwrap()
.iter()
.any(|v| v.is_some_and(|v| v.is_nan())),
DataType::Float32 => row
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
.unwrap()
.iter()
.any(|v| v.is_some_and(|v| v.is_nan())),
DataType::Float64 => row
.as_any()
.downcast_ref::<arrow_array::Float64Array>()
.unwrap()
.iter()
.any(|v| v.is_some_and(|v| v.is_nan())),
_ => false,
}
});
if has_nan {
return Err(datafusion_common::DataFusionError::ArrowError(
Box::new(arrow_schema::ArrowError::ComputeError(
"Vector column contains NaN values".to_string(),
)),
None,
));
}
Ok(())
}
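The null-handling rule in `check_no_nans` can be sketched with plain slices: values backing invalid (null) rows may legally contain garbage, including NaN, so only valid rows are inspected. `has_disallowed_nan` is a hypothetical helper for illustration, not part of this module.

```rust
// Sketch of the NaN-rejection rule from `check_no_nans`: chunk the flat value
// buffer into rows of `dim` elements and pair each row with its validity bit.
// NaNs inside null rows are ignored; NaNs inside valid rows fail the batch.
fn has_disallowed_nan(values: &[f32], dim: usize, row_valid: &[bool]) -> bool {
    values
        .chunks(dim)
        .zip(row_valid)
        .filter(|(_, valid)| **valid)
        .any(|(row, _)| row.iter().any(|v| v.is_nan()))
}

fn main() {
    // NaN backing a null row is fine; NaN inside a valid row is rejected.
    assert!(!has_disallowed_nan(&[1.0, 2.0, f32::NAN, f32::NAN], 2, &[true, false]));
    assert!(has_disallowed_nan(&[0.0, 0.0, 1.0, f32::NAN], 2, &[false, true]));
    println!("ok");
}
```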
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::Float32Array;
#[test]
fn test_passes_clean_vectors() {
let fsl = FixedSizeListArray::try_new(
Arc::new(Field::new("item", DataType::Float32, true)),
2,
Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])),
None,
)
.unwrap();
assert!(check_no_nans(&fsl).is_ok());
}
#[test]
fn test_rejects_nan_vectors() {
let fsl = FixedSizeListArray::try_new(
Arc::new(Field::new("item", DataType::Float32, true)),
2,
Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, 4.0])),
None,
)
.unwrap();
assert!(check_no_nans(&fsl).is_err());
}
#[test]
fn test_skips_null_rows() {
// Values backing null rows may contain NaN per the Arrow spec.
// We should not reject a batch just because of garbage in null slots.
let values = Float32Array::from(vec![1.0, 2.0, f32::NAN, f32::NAN]);
let fsl = FixedSizeListArray::try_new(
Arc::new(Field::new("item", DataType::Float32, true)),
2,
Arc::new(values),
// Row 0 is valid [1.0, 2.0], row 1 is null [NAN, NAN]
Some(vec![true, false].into()),
)
.unwrap();
assert!(fsl.is_null(1));
assert!(check_no_nans(&fsl).is_ok());
}
#[test]
fn test_skips_null_elements_within_valid_row() {
// A valid row with null child elements: the underlying buffer may hold
// NaN but the null bitmap says they're absent — should not reject.
let values = Float32Array::from(vec![
Some(1.0),
None, // null element — buffer may contain NaN
Some(3.0),
None, // null element
]);
let fsl = FixedSizeListArray::try_new(
Arc::new(Field::new("item", DataType::Float32, true)),
2,
Arc::new(values),
None, // both rows are valid
)
.unwrap();
assert!(check_no_nans(&fsl).is_ok());
}
#[test]
fn test_rejects_nan_in_valid_row_with_nulls_present() {
// Row 0 is null, row 1 is valid but contains NaN — should reject.
let values = Float32Array::from(vec![0.0, 0.0, 1.0, f32::NAN]);
let fsl = FixedSizeListArray::try_new(
Arc::new(Field::new("item", DataType::Float32, true)),
2,
Arc::new(values),
Some(vec![false, true].into()),
)
.unwrap();
assert!(check_no_nans(&fsl).is_err());
}
#[test]
fn test_is_vector_field() {
assert!(is_vector_field(&Field::new(
"v",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
)));
assert!(is_vector_field(&Field::new(
"v",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 4),
false,
)));
assert!(!is_vector_field(&Field::new("id", DataType::Int32, false)));
assert!(!is_vector_field(&Field::new(
"v",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4),
false,
)));
}
}

View File

@@ -0,0 +1,118 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use core::fmt;
use std::sync::{Arc, Mutex};
use datafusion_common::{stats::Precision, DataFusionError, Result as DFResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_plan::{
execution_plan::EmissionType, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use crate::{arrow::SendableRecordBatchStreamExt, data::scannable::Scannable};
pub struct ScannableExec {
// We don't require Scannable to be Sync, so we wrap it in a Mutex to allow safe concurrent access.
source: Mutex<Box<dyn Scannable>>,
num_rows: Option<usize>,
properties: PlanProperties,
}
impl std::fmt::Debug for ScannableExec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ScannableExec")
.field("schema", &self.schema())
.field("num_rows", &self.num_rows)
.finish()
}
}
impl ScannableExec {
pub fn new(source: Box<dyn Scannable>) -> Self {
let schema = source.schema();
let eq_properties = EquivalenceProperties::new(schema);
let properties = PlanProperties::new(
eq_properties,
Partitioning::UnknownPartitioning(1),
EmissionType::Incremental,
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
);
let num_rows = source.num_rows();
let source = Mutex::new(source);
Self {
source,
num_rows,
properties,
}
}
}
impl DisplayAs for ScannableExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ScannableExec: num_rows={:?}", self.num_rows)
}
}
impl ExecutionPlan for ScannableExec {
fn name(&self) -> &str {
"ScannableExec"
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
if !children.is_empty() {
return Err(DataFusionError::Internal(
"ScannableExec does not have children".to_string(),
));
}
Ok(self)
}
fn execute(
&self,
partition: usize,
_context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
if partition != 0 {
return Err(DataFusionError::Internal(format!(
"ScannableExec only supports partition 0, got {}",
partition
)));
}
let stream = match self.source.lock() {
Ok(mut guard) => guard.scan_as_stream(),
Err(poison) => poison.into_inner().scan_as_stream(),
};
Ok(stream.into_df_stream())
}
fn partition_statistics(&self, _partition: Option<usize>) -> DFResult<Statistics> {
Ok(Statistics {
num_rows: self
.num_rows
.map(Precision::Exact)
.unwrap_or(Precision::Absent),
total_byte_size: Precision::Absent,
column_statistics: vec![],
})
}
}
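The comment on `source` relies on a standard property of `Mutex`: `Mutex<T>` is `Sync` whenever `T: Send`, even if `T` itself is not `Sync`. A minimal standalone sketch (the `NotSync` type and `assert_sync` helper are illustrative, not part of this module):

```rust
use std::cell::Cell;
use std::sync::Mutex;

// A type that is Send but not Sync: Cell makes shared references unusable
// across threads, mirroring a `dyn Scannable` that is not required to be Sync.
struct NotSync {
    calls: Cell<u32>,
}

// Compiles only for types that are Sync.
fn assert_sync<T: Sync>(_: &T) {}

fn main() {
    let wrapped = Mutex::new(NotSync { calls: Cell::new(0) });
    // Mutex<NotSync> is Sync because NotSync is Send, so this compiles.
    assert_sync(&wrapped);
    wrapped.lock().unwrap().calls.set(1);
    assert_eq!(wrapped.lock().unwrap().calls.get(), 1);
    println!("ok");
}
```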

View File

@@ -2,301 +2,501 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
ops::{Deref, DerefMut},
sync::Arc,
time::{self, Duration, Instant},
sync::{Arc, Mutex},
time::Duration,
};
use lance::{dataset::refs, Dataset};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::error::Result;
/// A wrapper around a [Dataset] that provides lazy-loading and consistency checks.
///
/// This can be cloned cheaply. It supports concurrent reads or exclusive writes.
#[derive(Debug, Clone)]
pub struct DatasetConsistencyWrapper(Arc<RwLock<DatasetRef>>);
use crate::{error::Result, utils::background_cache::BackgroundCache, Error};
/// A wrapper around a [Dataset] that provides consistency checks.
///
/// The dataset is lazily loaded, and starts off as None. On the first access,
/// the dataset is loaded.
/// This can be cloned cheaply. Callers get an [`Arc<Dataset>`] from [`get()`](Self::get)
/// and call [`update()`](Self::update) after writes to store the new version.
#[derive(Debug, Clone)]
enum DatasetRef {
/// In this mode, the dataset is always the latest version.
Latest {
dataset: Dataset,
read_consistency_interval: Option<Duration>,
last_consistency_check: Option<time::Instant>,
},
/// In this mode, the dataset is a specific version. It cannot be mutated.
TimeTravel { dataset: Dataset, version: u64 },
pub struct DatasetConsistencyWrapper {
state: Arc<Mutex<DatasetState>>,
consistency: ConsistencyMode,
}
impl DatasetRef {
/// Reload the dataset to the appropriate version.
async fn reload(&mut self) -> Result<()> {
match self {
Self::Latest {
dataset,
last_consistency_check,
..
} => {
dataset.checkout_latest().await?;
last_consistency_check.replace(Instant::now());
}
Self::TimeTravel { dataset, version } => {
dataset.checkout_version(*version).await?;
}
}
Ok(())
}
/// The current dataset and whether it is pinned to a specific version.
#[derive(Debug, Clone)]
struct DatasetState {
dataset: Arc<Dataset>,
/// `Some(version)` = pinned to a specific version (time travel),
/// `None` = tracking latest.
pinned_version: Option<u64>,
}
fn is_latest(&self) -> bool {
matches!(self, Self::Latest { .. })
}
async fn as_latest(&mut self, read_consistency_interval: Option<Duration>) -> Result<()> {
match self {
Self::Latest { .. } => Ok(()),
Self::TimeTravel { dataset, .. } => {
dataset
.checkout_version(dataset.latest_version_id().await?)
.await?;
*self = Self::Latest {
dataset: dataset.clone(),
read_consistency_interval,
last_consistency_check: Some(Instant::now()),
};
Ok(())
}
}
}
async fn as_time_travel(&mut self, target_version: impl Into<refs::Ref>) -> Result<()> {
let target_ref = target_version.into();
match self {
Self::Latest { dataset, .. } => {
let new_dataset = dataset.checkout_version(target_ref.clone()).await?;
let version_value = new_dataset.version().version;
*self = Self::TimeTravel {
dataset: new_dataset,
version: version_value,
};
}
Self::TimeTravel { dataset, version } => {
let should_checkout = match &target_ref {
refs::Ref::Version(_, Some(target_ver)) => version != target_ver,
refs::Ref::Version(_, None) => true, // No specific version, always checkout
refs::Ref::VersionNumber(target_ver) => version != target_ver,
refs::Ref::Tag(_) => true, // Always checkout for tags
};
if should_checkout {
let new_dataset = dataset.checkout_version(target_ref).await?;
let version_value = new_dataset.version().version;
*self = Self::TimeTravel {
dataset: new_dataset,
version: version_value,
};
}
}
}
Ok(())
}
fn is_up_to_date(&self) -> bool {
match self {
Self::Latest {
read_consistency_interval,
last_consistency_check,
..
} => match (read_consistency_interval, last_consistency_check) {
(None, _) => true,
(Some(_), None) => false,
(Some(interval), Some(last_check)) => last_check.elapsed() < *interval,
},
Self::TimeTravel { dataset, version } => dataset.version().version == *version,
}
}
fn time_travel_version(&self) -> Option<u64> {
match self {
Self::Latest { .. } => None,
Self::TimeTravel { version, .. } => Some(*version),
}
}
fn set_latest(&mut self, dataset: Dataset) {
match self {
Self::Latest {
dataset: ref mut ds,
..
} => {
if dataset.manifest().version > ds.manifest().version {
*ds = dataset;
}
}
_ => unreachable!("Dataset should be in latest mode at this point"),
}
}
#[derive(Debug, Clone)]
enum ConsistencyMode {
/// Only update table state when explicitly asked.
Lazy,
/// Always check for a new version on every read.
Strong,
/// Periodically check for a new version in the background. If the table is being
/// accessed regularly, refreshes happen in the background. If the table has been
/// idle for a while, the next access triggers a refresh before returning the dataset.
///
/// read_consistency_interval = TTL
/// refresh_window = min(3s, TTL/4)
///
/// | t < TTL - refresh_window | TTL - refresh_window <= t < TTL   | t >= TTL            |
/// | return cached value      | background refresh & cached value | synchronous refresh |
Eventual(BackgroundCache<Arc<Dataset>, Error>),
}
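The TTL/refresh-window schedule documented on `Eventual` can be sketched as a pure function of elapsed time. `schedule` and `Action` are hypothetical names for illustration; the real decision is made inside `BackgroundCache`.

```rust
use std::time::Duration;

// Sketch of the Eventual-consistency schedule: within TTL - refresh_window the
// cached dataset is returned as-is; inside the refresh window a background
// refresh is started while the cached value is still returned; past the TTL
// the refresh happens synchronously before returning.
#[derive(Debug, PartialEq)]
enum Action {
    UseCached,
    RefreshInBackground,
    RefreshSynchronously,
}

fn schedule(ttl: Duration, elapsed: Duration) -> Action {
    let refresh_window = std::cmp::min(Duration::from_secs(3), ttl / 4);
    if elapsed >= ttl {
        Action::RefreshSynchronously
    } else if elapsed >= ttl - refresh_window {
        Action::RefreshInBackground
    } else {
        Action::UseCached
    }
}

fn main() {
    let ttl = Duration::from_secs(8); // refresh_window = min(3s, 2s) = 2s
    assert_eq!(schedule(ttl, Duration::from_secs(1)), Action::UseCached);
    assert_eq!(schedule(ttl, Duration::from_secs(7)), Action::RefreshInBackground);
    assert_eq!(schedule(ttl, Duration::from_secs(9)), Action::RefreshSynchronously);
    println!("ok");
}
```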
impl DatasetConsistencyWrapper {
/// Create a new wrapper in the latest version mode.
pub fn new_latest(dataset: Dataset, read_consistency_interval: Option<Duration>) -> Self {
Self(Arc::new(RwLock::new(DatasetRef::Latest {
dataset,
read_consistency_interval,
last_consistency_check: Some(Instant::now()),
})))
let dataset = Arc::new(dataset);
let consistency = match read_consistency_interval {
Some(d) if d == Duration::ZERO => ConsistencyMode::Strong,
Some(d) => {
let refresh_window = std::cmp::min(std::time::Duration::from_secs(3), d / 4);
let cache = BackgroundCache::new(d, refresh_window);
cache.seed(dataset.clone());
ConsistencyMode::Eventual(cache)
}
None => ConsistencyMode::Lazy,
};
Self {
state: Arc::new(Mutex::new(DatasetState {
dataset,
pinned_version: None,
})),
consistency,
}
}
/// Get the current dataset.
///
/// Behavior depends on the consistency mode:
/// - **Lazy** (`None`): returns the cached dataset immediately.
/// - **Strong** (`Some(ZERO)`): checks for a new version before returning.
/// - **Eventual** (`Some(d)` where `d > 0`): returns a cached value immediately
/// while refreshing in the background when the TTL expires.
///
/// If pinned to a specific version (time travel), always returns the
/// pinned dataset regardless of consistency mode.
pub async fn get(&self) -> Result<Arc<Dataset>> {
{
let state = self.state.lock().unwrap();
if state.pinned_version.is_some() {
return Ok(state.dataset.clone());
}
}
match &self.consistency {
ConsistencyMode::Eventual(bg_cache) => {
if let Some(dataset) = bg_cache.try_get() {
return Ok(dataset);
}
let state = self.state.clone();
bg_cache
.get(move || refresh_latest(state))
.await
.map_err(unwrap_shared_error)
}
ConsistencyMode::Strong => refresh_latest(self.state.clone()).await,
ConsistencyMode::Lazy => {
let state = self.state.lock().unwrap();
Ok(state.dataset.clone())
}
}
}
/// Store a new dataset version after a write operation.
///
/// Only stores the dataset if its version is at least as new as the current one.
/// Same-version updates are accepted for operations like manifest path migration
/// that modify the dataset without creating a new version.
/// If the wrapper has since transitioned to time-travel mode (e.g. via a
/// concurrent [`as_time_travel`](Self::as_time_travel) call), the update
/// is silently ignored — the write already committed to storage.
pub fn update(&self, dataset: Dataset) {
let mut state = self.state.lock().unwrap();
if state.pinned_version.is_some() {
// A concurrent as_time_travel() beat us here. The write succeeded
// in storage, but since we're now pinned we don't advance the
// cached pointer.
return;
}
if dataset.manifest().version >= state.dataset.manifest().version {
state.dataset = Arc::new(dataset);
}
drop(state);
if let ConsistencyMode::Eventual(bg_cache) = &self.consistency {
bg_cache.invalidate();
}
}
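The guard in `update` (ignore writes while pinned, accept only versions at least as new as the cached one) reduces to a small decision function. `CachedState` and `apply_update` are hypothetical stand-ins for the wrapper's state, kept here to show just the version logic:

```rust
/// Hypothetical stand-in for the wrapper's mutex-protected state.
struct CachedState {
    version: u64,
    pinned_version: Option<u64>,
}

/// Mirrors the guard in update(): ignore the write while pinned, and
/// otherwise accept only same-or-newer versions. Returns whether the
/// cached pointer was replaced.
fn apply_update(state: &mut CachedState, incoming_version: u64) -> bool {
    if state.pinned_version.is_some() {
        return false; // a concurrent time-travel checkout won the race
    }
    if incoming_version >= state.version {
        state.version = incoming_version;
        return true;
    }
    false // stale dataset from a slower writer; keep the newer one
}

fn main() {
    let mut state = CachedState { version: 2, pinned_version: None };
    assert!(apply_update(&mut state, 3)); // newer: accepted
    assert!(!apply_update(&mut state, 1)); // older: ignored
    state.pinned_version = Some(1);
    assert!(!apply_update(&mut state, 4)); // pinned: ignored
}
```

Accepting same-version updates is what allows manifest-path migrations, which rewrite the dataset without bumping the version, to still land in the cache.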
/// Checkout a branch and track its HEAD for new versions.
pub async fn as_branch(&self, _branch: impl Into<String>) -> Result<()> {
todo!("Branch support not yet implemented")
}
/// Check that the dataset is in a mutable mode (Latest).
pub fn ensure_mutable(&self) -> Result<()> {
let state = self.state.lock().unwrap();
if state.pinned_version.is_some() {
Err(crate::Error::InvalidInput {
message: "table cannot be modified when a specific version is checked out"
.to_string(),
})
} else {
Ok(())
}
}
/// Returns the version, if in time travel mode, or None otherwise.
pub fn time_travel_version(&self) -> Option<u64> {
self.state.lock().unwrap().pinned_version
}
/// Convert into a wrapper in latest version mode.
pub async fn as_latest(&self) -> Result<()> {
let dataset = {
let state = self.state.lock().unwrap();
if state.pinned_version.is_none() {
return Ok(());
}
state.dataset.clone()
};
let latest_version = dataset.latest_version_id().await?;
let new_dataset = dataset.checkout_version(latest_version).await?;
let mut state = self.state.lock().unwrap();
if state.pinned_version.is_some() {
state.dataset = Arc::new(new_dataset);
state.pinned_version = None;
}
drop(state);
if let ConsistencyMode::Eventual(bg_cache) = &self.consistency {
bg_cache.invalidate();
}
Ok(())
}
pub async fn as_time_travel(&self, target_version: impl Into<refs::Ref>) -> Result<()> {
let target_ref = target_version.into();
let (should_checkout, dataset) = {
let state = self.state.lock().unwrap();
let should = match state.pinned_version {
None => true,
Some(version) => match &target_ref {
refs::Ref::Version(_, Some(target_ver)) => version != *target_ver,
refs::Ref::Version(_, None) => true,
refs::Ref::VersionNumber(target_ver) => version != *target_ver,
refs::Ref::Tag(_) => true,
},
};
(should, state.dataset.clone())
};
if !should_checkout {
return Ok(());
}
let new_dataset = dataset.checkout_version(target_ref).await?;
let version_value = new_dataset.version().version;
let mut state = self.state.lock().unwrap();
state.dataset = Arc::new(new_dataset);
state.pinned_version = Some(version_value);
Ok(())
}
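The `should_checkout` decision above (skip the checkout only when already pinned to exactly the requested numeric version; tags and unresolved refs always force one) can be sketched on its own. The `Ref` enum here is a simplified stand-in for `refs::Ref`, which also carries resolved-version and branch information:

```rust
/// Simplified stand-in for refs::Ref.
enum Ref {
    VersionNumber(u64),
    Tag(String),
}

/// Checkout is skipped only when we are already pinned to the exact
/// numeric version requested; tags always force a checkout because
/// the version a tag resolves to may have changed.
fn should_checkout(pinned: Option<u64>, target: &Ref) -> bool {
    match pinned {
        None => true,
        Some(version) => match target {
            Ref::VersionNumber(target_ver) => version != *target_ver,
            Ref::Tag(_) => true,
        },
    }
}

fn main() {
    assert!(should_checkout(None, &Ref::VersionNumber(3)));
    assert!(!should_checkout(Some(3), &Ref::VersionNumber(3)));
    assert!(should_checkout(Some(3), &Ref::Tag("stable".to_string())));
}
```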
pub async fn reload(&self) -> Result<()> {
let (dataset, pinned_version) = {
let state = self.state.lock().unwrap();
(state.dataset.clone(), state.pinned_version)
};
match pinned_version {
None => {
refresh_latest(self.state.clone()).await?;
if let ConsistencyMode::Eventual(bg_cache) = &self.consistency {
bg_cache.invalidate();
}
}
Some(version) => {
if dataset.version().version == version {
return Ok(());
}
let new_dataset = dataset.checkout_version(version).await?;
let mut state = self.state.lock().unwrap();
if state.pinned_version == Some(version) {
state.dataset = Arc::new(new_dataset);
}
}
}
Ok(())
}
}
async fn refresh_latest(state: Arc<Mutex<DatasetState>>) -> Result<Arc<Dataset>> {
let dataset = { state.lock().unwrap().dataset.clone() };
let mut ds = (*dataset).clone();
ds.checkout_latest().await?;
let new_arc = Arc::new(ds);
{
let mut state = state.lock().unwrap();
if state.pinned_version.is_none()
&& new_arc.manifest().version >= state.dataset.manifest().version
{
state.dataset = new_arc.clone();
}
}
Ok(new_arc)
}
fn unwrap_shared_error(arc: Arc<Error>) -> Error {
match Arc::try_unwrap(arc) {
Ok(err) => err,
Err(arc) => Error::Runtime {
message: arc.to_string(),
},
}
}
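`unwrap_shared_error` relies on a general pattern: `Arc::try_unwrap` recovers the owned value only when the caller holds the last reference, and otherwise you must fall back to a copy. A minimal stdlib demonstration of the same shape (the `take_or_clone` helper is illustrative):

```rust
use std::sync::Arc;

/// Recover the owned String if we hold the only Arc; otherwise fall
/// back to a clone, mirroring how unwrap_shared_error falls back to a
/// stringified error when the Arc is still shared.
fn take_or_clone(arc: Arc<String>) -> String {
    match Arc::try_unwrap(arc) {
        Ok(owned) => owned,
        Err(shared) => shared.as_str().to_owned(),
    }
}

fn main() {
    let sole = Arc::new("only reference".to_string());
    assert_eq!(take_or_clone(sole), "only reference");

    let shared = Arc::new("shared".to_string());
    let _other = shared.clone(); // second reference forces the Err arm
    assert_eq!(take_or_clone(shared), "shared");
}
```

The lossy fallback matters here because `BackgroundCache` hands the same `Arc<Error>` to every waiter of a failed refresh, so at most one of them can take ownership.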
#[cfg(test)]
mod tests {
use std::time::Instant;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use lance::{
dataset::{WriteMode, WriteParams},
io::ObjectStoreParams,
};
use super::*;
use crate::{connect, io::object_store::io_tracking::IoStatsHolder, table::WriteOptions};
async fn create_test_dataset(uri: &str) -> Dataset {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
Dataset::write(
RecordBatchIterator::new(vec![Ok(batch)], schema),
uri,
Some(WriteParams::default()),
)
.await
.unwrap()
}
async fn append_to_dataset(uri: &str) -> Dataset {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![4, 5, 6]))],
)
.unwrap();
Dataset::write(
RecordBatchIterator::new(vec![Ok(batch)], schema),
uri,
Some(WriteParams {
mode: WriteMode::Append,
..Default::default()
}),
)
.await
.unwrap()
}
#[tokio::test]
async fn test_get_returns_dataset() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let version = ds.version().version;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
let ds1 = wrapper.get().await.unwrap();
let ds2 = wrapper.get().await.unwrap();
assert_eq!(ds1.version().version, version);
assert_eq!(ds2.version().version, version);
// Arc<Dataset> is independent — not borrowing from wrapper
drop(wrapper);
assert_eq!(ds1.version().version, version);
}
#[tokio::test]
async fn test_update_stores_newer_version() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds_v1 = create_test_dataset(uri).await;
assert_eq!(ds_v1.version().version, 1);
let wrapper = DatasetConsistencyWrapper::new_latest(ds_v1, None);
let ds_v2 = append_to_dataset(uri).await;
assert_eq!(ds_v2.version().version, 2);
wrapper.update(ds_v2);
let ds = wrapper.get().await.unwrap();
assert_eq!(ds.version().version, 2);
}
#[tokio::test]
async fn test_update_ignores_older_version() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds_v1 = create_test_dataset(uri).await;
let ds_v2 = append_to_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds_v2, None);
wrapper.update(ds_v1);
let ds = wrapper.get().await.unwrap();
assert_eq!(ds.version().version, 2);
}
#[tokio::test]
async fn test_ensure_mutable_allows_latest() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
assert!(wrapper.ensure_mutable().is_ok());
}
#[tokio::test]
async fn test_ensure_mutable_rejects_time_travel() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
wrapper.as_time_travel(1u64).await.unwrap();
assert!(wrapper.ensure_mutable().is_err());
}
#[tokio::test]
async fn test_time_travel_version() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
assert_eq!(wrapper.time_travel_version(), None);
wrapper.as_time_travel(1u64).await.unwrap();
assert_eq!(wrapper.time_travel_version(), Some(1));
}
#[tokio::test]
async fn test_as_latest_from_time_travel() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
wrapper.as_time_travel(1u64).await.unwrap();
assert!(wrapper.ensure_mutable().is_err());
wrapper.as_latest().await.unwrap();
assert!(wrapper.ensure_mutable().is_ok());
assert_eq!(wrapper.time_travel_version(), None);
}
#[tokio::test]
async fn test_lazy_consistency_never_refreshes() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
let v1 = wrapper.get().await.unwrap().version().version;
// External write
append_to_dataset(uri).await;
// Lazy consistency should not pick up external write
let v_after = wrapper.get().await.unwrap().version().version;
assert_eq!(v1, v_after);
}
#[tokio::test]
async fn test_strong_consistency_always_refreshes() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, Some(Duration::ZERO));
let v1 = wrapper.get().await.unwrap().version().version;
// External write
append_to_dataset(uri).await;
// Strong consistency should pick up external write
let v_after = wrapper.get().await.unwrap().version().version;
assert_eq!(v_after, v1 + 1);
}
#[tokio::test]
async fn test_eventual_consistency_background_refresh() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, Some(Duration::from_millis(200)));
// Populate the cache
let v1 = wrapper.get().await.unwrap().version().version;
assert_eq!(v1, 1);
// External write
append_to_dataset(uri).await;
// Should return cached value immediately (within TTL)
let v_cached = wrapper.get().await.unwrap().version().version;
assert_eq!(v_cached, 1);
// Wait for TTL to expire, then get() should trigger a refresh
tokio::time::sleep(Duration::from_millis(300)).await;
let v_after = wrapper.get().await.unwrap().version().version;
assert_eq!(v_after, 2);
}
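The TTL behavior this test exercises (cached value served inside the TTL, refresh once it expires) can be sketched with a minimal stdlib cache. `TtlCache` is illustrative only, not the real `BackgroundCache`, which additionally refreshes in the background inside the refresh window:

```rust
use std::time::{Duration, Instant};

/// Illustrative single-value TTL cache: `get` returns the cached value
/// while it is fresh and re-runs the loader once the TTL has expired.
struct TtlCache<T> {
    ttl: Duration,
    entry: Option<(T, Instant)>,
}

impl<T: Clone> TtlCache<T> {
    fn new(ttl: Duration) -> Self {
        Self { ttl, entry: None }
    }

    fn get(&mut self, load: impl FnOnce() -> T) -> T {
        if let Some((value, at)) = &self.entry {
            if at.elapsed() < self.ttl {
                return value.clone(); // still fresh: serve cached value
            }
        }
        // Empty or expired: run the loader and restamp the entry.
        let value = load();
        self.entry = Some((value.clone(), Instant::now()));
        value
    }
}

fn main() {
    let mut cache = TtlCache::new(Duration::from_millis(50));
    assert_eq!(cache.get(|| 1), 1); // loader runs: cache was empty
    assert_eq!(cache.get(|| 2), 1); // within TTL: cached value wins
    std::thread::sleep(Duration::from_millis(60));
    assert_eq!(cache.get(|| 2), 2); // TTL expired: loader re-runs
}
```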
#[tokio::test]
async fn test_eventual_consistency_update_invalidates_cache() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds_v1 = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds_v1, Some(Duration::from_secs(60)));
// Simulate a write that produces v2
let ds_v2 = append_to_dataset(uri).await;
wrapper.update(ds_v2);
// get() should return v2 immediately (update invalidated the bg_cache,
// and the mutex state was updated)
let v = wrapper.get().await.unwrap().version().version;
assert_eq!(v, 2);
}
#[tokio::test]
async fn test_iops_open_strong_consistency() {
let db = connect("memory://")
@@ -312,7 +512,7 @@ mod tests {
.create_empty_table("test", schema)
.write_options(WriteOptions {
lance_write_params: Some(WriteParams {
store_params: Some(lance::io::ObjectStoreParams {
object_store_wrapper: Some(Arc::new(io_stats.clone())),
..Default::default()
}),
@@ -332,6 +532,31 @@ mod tests {
assert_eq!(stats.read_iops, 1);
}
/// Regression test: a write that races with as_time_travel() must not panic.
///
/// Sequence: ensure_mutable() passes → as_time_travel() completes → write
/// calls update(). Previously the assert!() in update() would fire.
#[tokio::test]
async fn test_update_after_concurrent_time_travel_does_not_panic() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds_v1 = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds_v1, None);
// Simulate: as_time_travel() completes just before the write's update().
wrapper.as_time_travel(1u64).await.unwrap();
assert_eq!(wrapper.time_travel_version(), Some(1));
// The write already committed to storage; now it calls update().
// This must not panic, and the wrapper must stay pinned.
let ds_v2 = append_to_dataset(uri).await;
wrapper.update(ds_v2);
let ds = wrapper.get().await.unwrap();
assert_eq!(ds.version().version, 1);
}
/// Regression test: before the fix, the reload fast-path (no version change)
/// did not reset `last_consistency_check`, causing a list call on every
/// subsequent query once the interval expired.

View File

@@ -7,6 +7,9 @@ use crate::Result;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct DeleteResult {
/// The number of rows that were deleted.
#[serde(default)]
pub num_deleted_rows: u64,
/// The commit version associated with the operation.
/// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
@@ -18,16 +21,15 @@ pub struct DeleteResult {
///
/// This logic was moved from NativeTable::delete to keep table.rs clean.
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
table.dataset.ensure_mutable()?;
let mut dataset = (*table.dataset.get().await?).clone();
let delete_result = dataset.delete(predicate).await?;
let num_deleted_rows = delete_result.num_deleted_rows;
let version = dataset.version().version;
table.dataset.update(dataset);
Ok(DeleteResult {
num_deleted_rows,
version,
})
}
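`execute_delete` follows the wrapper's clone-modify-publish pattern: clone the `Arc`-shared snapshot, mutate the private copy, then hand it back via `update` so concurrent readers keep seeing the old snapshot until the write lands. The same shape with plain std types (the `Cache` type and names are illustrative):

```rust
use std::sync::{Arc, Mutex};

/// Illustrative cache holding an Arc'd snapshot, like the consistency wrapper.
struct Cache {
    current: Mutex<Arc<Vec<i32>>>,
}

impl Cache {
    /// Clone the shared snapshot, mutate the private copy, publish it back.
    /// Returns how many elements were removed, like num_deleted_rows.
    fn delete_matching(&self, predicate: impl Fn(&i32) -> bool) -> usize {
        let snapshot = self.current.lock().unwrap().clone();
        let mut copy = (*snapshot).clone(); // readers keep the old snapshot
        let before = copy.len();
        copy.retain(|v| !predicate(v));
        let deleted = before - copy.len();
        *self.current.lock().unwrap() = Arc::new(copy); // publish
        deleted
    }
}

fn main() {
    let cache = Cache { current: Mutex::new(Arc::new(vec![1, 2, 3, 4, 5])) };
    let deleted = cache.delete_matching(|v| *v > 3);
    assert_eq!(deleted, 2);
    assert_eq!(**cache.current.lock().unwrap(), vec![1, 2, 3]);
}
```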
@@ -113,6 +115,32 @@ mod tests {
assert_eq!(current_schema, original_schema);
}
#[tokio::test]
async fn test_delete_returns_num_deleted_rows() {
let conn = connect("memory://").execute().await.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
let table = conn
.create_table("test_num_deleted", batch)
.execute()
.await
.unwrap();
// Delete 2 rows (id > 3 means id=4 and id=5)
let result = table.delete("id > 3").await.unwrap();
assert_eq!(result.num_deleted_rows, 2);
assert_eq!(table.count_rows(None).await.unwrap(), 3);
// Delete 0 rows (no rows match)
let result = table.delete("id > 100").await.unwrap();
assert_eq!(result.num_deleted_rows, 0);
assert_eq!(table.count_rows(None).await.unwrap(), 3);
// Delete remaining rows
let result = table.delete("true").await.unwrap();
assert_eq!(result.num_deleted_rows, 3);
assert_eq!(table.count_rows(None).await.unwrap(), 0);
}
#[tokio::test]
async fn test_delete_false_increments_version() {
let conn = connect("memory://").execute().await.unwrap();

View File

@@ -1,13 +1,45 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use std::time::Duration;
use arrow_array::RecordBatchReader;
use futures::future::Either;
use futures::{FutureExt, TryFutureExt};
use lance::dataset::{
MergeInsertBuilder as LanceMergeInsertBuilder, WhenMatched, WhenNotMatchedBySource,
};
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use super::{BaseTable, NativeTable};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct MergeResult {
/// The commit version associated with the operation.
/// A version of `0` indicates compatibility with legacy servers that do not return
/// a commit version.
#[serde(default)]
pub version: u64,
/// Number of inserted rows (for user statistics)
#[serde(default)]
pub num_inserted_rows: u64,
/// Number of updated rows (for user statistics)
#[serde(default)]
pub num_updated_rows: u64,
/// Number of deleted rows (for user statistics)
/// Note: this differs from internal notions of 'deleted_rows', since updated rows
/// are technically "deleted" during processing; those internal deletions are not
/// reported to the user.
#[serde(default)]
pub num_deleted_rows: u64,
/// Number of attempts performed during the merge operation.
/// This includes the initial attempt plus any retries due to transaction conflicts.
/// A value of 1 means the operation succeeded on the first try.
#[serde(default)]
pub num_attempts: u32,
}
/// A builder used to create and run a merge insert operation
///
@@ -124,3 +156,172 @@ impl MergeInsertBuilder {
self.table.clone().merge_insert(self, new_data).await
}
}
/// Internal implementation of the merge insert logic
///
/// This logic was moved from NativeTable::merge_insert to keep table.rs clean.
pub(crate) async fn execute_merge_insert(
table: &NativeTable,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<MergeResult> {
let dataset = table.dataset.get().await?;
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
match (
params.when_matched_update_all,
params.when_matched_update_all_filt,
) {
(false, _) => builder.when_matched(WhenMatched::DoNothing),
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
};
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
} else {
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
}
if params.when_not_matched_by_source_delete {
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
} else {
WhenNotMatchedBySource::Delete
};
builder.when_not_matched_by_source(behavior);
} else {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
}
builder.use_index(params.use_index);
let future = if let Some(timeout) = params.timeout {
let future = builder
.retry_timeout(timeout)
.try_build()?
.execute_reader(new_data);
Either::Left(tokio::time::timeout(timeout, future).map(|res| match res {
Ok(Ok((new_dataset, stats))) => Ok((new_dataset, stats)),
Ok(Err(e)) => Err(e.into()),
Err(_) => Err(Error::Runtime {
message: "merge insert timed out".to_string(),
}),
}))
} else {
let job = builder.try_build()?;
Either::Right(job.execute_reader(new_data).map_err(|e| e.into()))
};
let (new_dataset, stats) = future.await?;
let version = new_dataset.manifest().version;
table.dataset.update(new_dataset.as_ref().clone());
Ok(MergeResult {
version,
num_updated_rows: stats.num_updated_rows,
num_inserted_rows: stats.num_inserted_rows,
num_deleted_rows: stats.num_deleted_rows,
num_attempts: stats.num_attempts,
})
}
#[cfg(test)]
mod tests {
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
use crate::connect;
fn merge_insert_test_batches(offset: i32, age: i32) -> Box<dyn RecordBatchReader + Send> {
let schema = Arc::new(Schema::new(vec![
Field::new("i", DataType::Int32, false),
Field::new("age", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(age, 10))),
],
)
.unwrap();
Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema))
}
#[tokio::test]
async fn test_merge_insert() {
let conn = connect("memory://").execute().await.unwrap();
// Create a dataset with i=0..10
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("my_table", batches)
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = merge_insert_test_batches(5, 1);
// Perform a "insert if not exists"
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
let result = merge_insert_builder.execute(new_batches).await.unwrap();
// Only 5 rows should actually be inserted
assert_eq!(table.count_rows(None).await.unwrap(), 15);
assert_eq!(result.num_inserted_rows, 5);
assert_eq!(result.num_updated_rows, 0);
assert_eq!(result.num_deleted_rows, 0);
assert_eq!(result.num_attempts, 1);
// Create new data with i=15..25 (no id matches)
let new_batches = merge_insert_test_batches(15, 2);
// Perform a "bulk update" (should not affect anything)
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all(None);
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows(None).await.unwrap(), 15);
assert_eq!(
table.count_rows(Some("age = 2".to_string())).await.unwrap(),
0
);
// Conditional update that only replaces the age=0 data
let new_batches = merge_insert_test_batches(5, 3);
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string()));
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
5
);
}
#[tokio::test]
async fn test_merge_insert_use_index() {
let conn = connect("memory://").execute().await.unwrap();
// Create a dataset with i=0..10
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("my_table", batches)
.execute()
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// Test use_index=true (default behavior)
let new_batches = merge_insert_test_batches(5, 1);
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.use_index(true);
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 15);
// Test use_index=false (force table scan)
let new_batches = merge_insert_test_batches(15, 2);
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.use_index(false);
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 25);
}
}

View File

@@ -26,8 +26,10 @@ use crate::error::Result;
/// optimize different parts of the table on disk.
///
/// By default, it optimizes everything, as [`OptimizeAction::All`].
#[derive(Default)]
pub enum OptimizeAction {
/// Run all optimizations with default values
#[default]
All,
/// Compacts files in the dataset
///
@@ -84,12 +86,6 @@ pub enum OptimizeAction {
Index(OptimizeOptions),
}
/// Statistics about the optimization.
#[derive(Debug, Default)]
pub struct OptimizeStats {
@@ -105,12 +101,10 @@ pub struct OptimizeStats {
/// This logic was moved from NativeTable to keep table.rs clean.
pub(crate) async fn optimize_indices(table: &NativeTable, options: &OptimizeOptions) -> Result<()> {
info!("LanceDB: optimizing indices: {:?}", options);
table.dataset.ensure_mutable()?;
let mut dataset = (*table.dataset.get().await?).clone();
dataset.optimize_indices(options).await?;
table.dataset.update(dataset);
Ok(())
}
@@ -131,10 +125,9 @@ pub(crate) async fn cleanup_old_versions(
delete_unverified: Option<bool>,
error_if_tagged_old_versions: Option<bool>,
) -> Result<RemovalStats> {
table.dataset.ensure_mutable()?;
let dataset = table.dataset.get().await?;
Ok(dataset
.cleanup_old_versions(older_than, delete_unverified, error_if_tagged_old_versions)
.await?)
}
@@ -150,8 +143,10 @@ pub(crate) async fn compact_files_impl(
options: CompactionOptions,
remap_options: Option<Arc<dyn IndexRemapperOptions>>,
) -> Result<CompactionMetrics> {
table.dataset.ensure_mutable()?;
let mut dataset = (*table.dataset.get().await?).clone();
let metrics = compact_files(&mut dataset, options, remap_options).await?;
table.dataset.update(dataset);
Ok(metrics)
}

View File

@@ -0,0 +1,738 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use super::NativeTable;
use crate::error::{Error, Result};
use crate::expr::expr_to_sql_string;
use crate::query::{
QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest, DEFAULT_TOP_K,
};
use crate::utils::{default_vector_column, TimeoutStream};
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::Array;
use arrow_schema::{DataType, Schema};
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::union::UnionExec;
use datafusion_physical_plan::ExecutionPlan;
use futures::future::try_join_all;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::scanner::Scanner;
use lance_datafusion::exec::{analyze_plan as lance_analyze_plan, execute_plan};
use lance_namespace::models::{
QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns,
QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery,
};
use lance_namespace::LanceNamespace;
#[derive(Debug, Clone)]
pub enum AnyQuery {
Query(QueryRequest),
VectorQuery(VectorQueryRequest),
}
// Decide between namespace (server-side) and local query execution
pub async fn execute_query(
table: &NativeTable,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
// If namespace client is configured, use server-side query execution
if let Some(ref namespace_client) = table.namespace_client {
return execute_namespace_query(table, namespace_client.clone(), query, options).await;
}
execute_generic_query(table, query, options).await
}
pub async fn analyze_query_plan(
table: &NativeTable,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<String> {
let plan = create_plan(table, query, options).await?;
Ok(lance_analyze_plan(plan, Default::default()).await?)
}
/// Local Execution Path (DataFusion)
async fn execute_generic_query(
table: &NativeTable,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let plan = create_plan(table, query, options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
inner
};
Ok(DatasetRecordBatchStream::new(inner))
}
pub async fn create_plan(
table: &NativeTable,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let query = match query {
AnyQuery::VectorQuery(query) => query.clone(),
AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()),
};
let ds_ref = table.dataset.get().await?;
let schema = ds_ref.schema();
let mut column = query.column.clone();
let mut query_vector = query.query_vector.first().cloned();
if query.query_vector.len() > 1 {
if column.is_none() {
// Infer a vector column with the same dimension as the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
column = Some(default_vector_column(
&arrow_schema,
Some(query.query_vector[0].len() as i32),
)?);
}
let vector_field = schema.field(column.as_ref().unwrap()).unwrap();
if let DataType::List(_) = vector_field.data_type() {
// Multivector handling: concatenate into FixedSizeList<FixedSizeList<_>>
let vectors = query
.query_vector
.iter()
.map(|arr| arr.as_ref())
.collect::<Vec<_>>();
let dim = vectors[0].len();
let mut fsl_builder = FixedSizeListBuilder::with_capacity(
Float32Builder::with_capacity(dim),
dim as i32,
vectors.len(),
);
for vec in vectors {
fsl_builder
.values()
.append_slice(vec.as_primitive::<Float32Type>().values());
fsl_builder.append(true);
}
query_vector = Some(Arc::new(fsl_builder.finish()));
} else {
// Multiple query vectors: create a plan for each and union them
let query_vecs = query.query_vector.clone();
let plan_futures = query_vecs
.into_iter()
.map(|query_vector| {
let mut sub_query = query.clone();
sub_query.query_vector = vec![query_vector];
let options_ref = options.clone();
async move {
create_plan(table, &AnyQuery::VectorQuery(sub_query), options_ref).await
}
})
.collect::<Vec<_>>();
let plans = try_join_all(plan_futures).await?;
return create_multi_vector_plan(plans);
}
}
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query_vector {
let column = if let Some(col) = column {
col
} else {
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let (_, element_type) = lance::index::vector::utils::get_vector_type(schema, &column)?;
let is_binary = matches!(element_type, DataType::UInt8);
let top_k = query.base.limit.unwrap_or(DEFAULT_TOP_K) + query.base.offset.unwrap_or(0);
if is_binary {
let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
let query_vector = query_vector.as_primitive::<UInt8Type>();
scanner.nearest(&column, query_vector, top_k)?;
} else {
scanner.nearest(&column, query_vector.as_ref(), top_k)?;
}
scanner.minimum_nprobes(query.minimum_nprobes);
if let Some(maximum_nprobes) = query.maximum_nprobes {
scanner.maximum_nprobes(maximum_nprobes);
}
}
scanner.limit(
query.base.limit.map(|limit| limit as i64),
query.base.offset.map(|offset| offset as i64),
)?;
if let Some(ef) = query.ef {
scanner.ef(ef);
}
scanner.distance_range(query.lower_bound, query.upper_bound);
scanner.use_index(query.use_index);
scanner.prefilter(query.base.prefilter);
match query.base.select {
Select::Columns(ref columns) => {
scanner.project(columns.as_slice())?;
}
Select::Dynamic(ref select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => {}
}
if query.base.with_row_id {
scanner.with_row_id();
}
scanner.batch_size(options.max_batch_length as usize);
if query.base.fast_search {
scanner.fast_search();
}
if let Some(filter) = &query.base.filter {
match filter {
QueryFilter::Sql(sql) => {
scanner.filter(sql)?;
}
QueryFilter::Substrait(substrait) => {
scanner.filter_substrait(substrait)?;
}
QueryFilter::Datafusion(expr) => {
scanner.filter_expr(expr.clone());
}
}
}
if let Some(fts) = &query.base.full_text_search {
scanner.full_text_search(fts.clone())?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type.into());
}
if query.base.disable_scoring_autoprojection {
scanner.disable_scoring_autoprojection();
}
Ok(scanner.create_plan().await?)
}
// Helper functions below
// Take many execution plans and map them into a single plan that adds
// a query_index column and unions them.
pub(crate) fn create_multi_vector_plan(
plans: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
if plans.is_empty() {
return Err(Error::InvalidInput {
message: "No plans provided".to_string(),
});
}
// Projection that keeps all existing columns
let first_plan = plans[0].clone();
let project_all_columns = first_plan
.schema()
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
let expr = datafusion_physical_plan::expressions::Column::new(field.name().as_str(), i);
let expr = Arc::new(expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
(expr, field.name().clone())
})
.collect::<Vec<_>>();
let projected_plans = plans
.into_iter()
.enumerate()
.map(|(plan_i, plan)| {
let query_index = datafusion_common::ScalarValue::Int32(Some(plan_i as i32));
let query_index_expr = datafusion_physical_plan::expressions::Literal::new(query_index);
let query_index_expr =
Arc::new(query_index_expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
let mut projections = vec![(query_index_expr, "query_index".to_string())];
projections.extend_from_slice(&project_all_columns);
let projection = ProjectionExec::try_new(projections, plan).unwrap();
Arc::new(projection) as Arc<dyn datafusion_physical_plan::ExecutionPlan>
})
.collect::<Vec<_>>();
let unioned = UnionExec::try_new(projected_plans).map_err(|err| Error::Runtime {
message: err.to_string(),
})?;
// We require 1 partition in the final output
let repartitioned = RepartitionExec::try_new(
unioned,
datafusion_physical_plan::Partitioning::RoundRobinBatch(1),
)
.unwrap();
Ok(Arc::new(repartitioned))
}
/// Execute a query on the namespace server instead of locally.
async fn execute_namespace_query(
table: &NativeTable,
namespace_client: Arc<dyn LanceNamespace>,
query: &AnyQuery,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
// Build table_id from namespace + table name
let mut table_id = table.namespace.clone();
table_id.push(table.name.clone());
// Convert AnyQuery to namespace QueryTableRequest
let mut ns_request = convert_to_namespace_query(query)?;
// Set the table ID on the request
ns_request.id = Some(table_id);
// Call the namespace query_table API
let response_bytes = namespace_client
.query_table(ns_request)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to execute server-side query: {}", e),
})?;
// Parse the Arrow IPC response into a RecordBatchStream
parse_arrow_ipc_response(response_bytes).await
}
/// Convert an AnyQuery to the namespace QueryTableRequest format.
fn convert_to_namespace_query(query: &AnyQuery) -> Result<NsQueryTableRequest> {
match query {
AnyQuery::VectorQuery(vq) => {
// Extract the query vector(s)
let vector = extract_query_vector(&vq.query_vector)?;
// Convert filter to SQL string
let filter = match &vq.base.filter {
Some(f) => Some(filter_to_sql(f)?),
None => None,
};
// Convert select to columns list
let columns = match &vq.base.select {
Select::All => None,
Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
column_names: Some(cols.clone()),
column_aliases: None,
})),
Select::Dynamic(_) => {
return Err(Error::NotSupported {
message:
"Dynamic column selection is not supported for server-side queries"
.to_string(),
});
}
};
// Check for unsupported features
if vq.base.reranker.is_some() {
return Err(Error::NotSupported {
message: "Reranker is not supported for server-side queries".to_string(),
});
}
// Convert FTS query if present
let full_text_query = vq.base.full_text_search.as_ref().map(|fts| {
let columns = fts.columns();
let columns_vec = if columns.is_empty() {
None
} else {
Some(columns.into_iter().collect())
};
Box::new(QueryTableRequestFullTextQuery {
string_query: Some(Box::new(StringFtsQuery {
query: fts.query.to_string(),
columns: columns_vec,
})),
structured_query: None,
})
});
Ok(NsQueryTableRequest {
id: None, // Will be set by execute_namespace_query
k: vq.base.limit.unwrap_or(10) as i32,
vector: Box::new(vector),
vector_column: vq.column.clone(),
filter,
columns,
offset: vq.base.offset.map(|o| o as i32),
distance_type: vq.distance_type.map(|dt| dt.to_string()),
nprobes: Some(vq.minimum_nprobes as i32),
ef: vq.ef.map(|e| e as i32),
refine_factor: vq.refine_factor.map(|r| r as i32),
lower_bound: vq.lower_bound,
upper_bound: vq.upper_bound,
prefilter: Some(vq.base.prefilter),
fast_search: Some(vq.base.fast_search),
with_row_id: Some(vq.base.with_row_id),
bypass_vector_index: Some(!vq.use_index),
full_text_query,
..Default::default()
})
}
AnyQuery::Query(q) => {
// For non-vector queries, pass an empty vector (similar to the remote table implementation)
if q.reranker.is_some() {
return Err(Error::NotSupported {
message: "Reranker is not supported for server-side query execution"
.to_string(),
});
}
let filter = q.filter.as_ref().map(filter_to_sql).transpose()?;
let columns = match &q.select {
Select::All => None,
Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns {
column_names: Some(cols.clone()),
column_aliases: None,
})),
Select::Dynamic(_) => {
return Err(Error::NotSupported {
message: "Dynamic columns are not supported for server-side query"
.to_string(),
});
}
};
// Handle full text search if present
let full_text_query = q.full_text_search.as_ref().map(|fts| {
let columns_vec = if fts.columns().is_empty() {
None
} else {
Some(fts.columns().iter().cloned().collect())
};
Box::new(QueryTableRequestFullTextQuery {
string_query: Some(Box::new(StringFtsQuery {
query: fts.query.to_string(),
columns: columns_vec,
})),
structured_query: None,
})
});
// Empty vector for non-vector queries
let vector = Box::new(QueryTableRequestVector {
single_vector: Some(vec![]),
multi_vector: None,
});
Ok(NsQueryTableRequest {
id: None, // Will be set by caller
vector,
k: q.limit.unwrap_or(10) as i32,
filter,
columns,
prefilter: Some(q.prefilter),
offset: q.offset.map(|o| o as i32),
vector_column: None, // No vector column for plain queries
with_row_id: Some(q.with_row_id),
bypass_vector_index: Some(true), // No vector index for plain queries
full_text_query,
..Default::default()
})
}
}
}
fn filter_to_sql(filter: &QueryFilter) -> Result<String> {
match filter {
QueryFilter::Sql(sql) => Ok(sql.clone()),
QueryFilter::Substrait(_) => Err(Error::NotSupported {
message: "Substrait filters are not supported for server-side queries".to_string(),
}),
QueryFilter::Datafusion(expr) => expr_to_sql_string(expr),
}
}
/// Extract query vector(s) from Arrow arrays into the namespace format.
fn extract_query_vector(
query_vectors: &[Arc<dyn arrow_array::Array>],
) -> Result<QueryTableRequestVector> {
if query_vectors.is_empty() {
return Err(Error::InvalidInput {
message: "Query vector is required for vector search".to_string(),
});
}
// Handle single vector case
if query_vectors.len() == 1 {
let arr = &query_vectors[0];
let single_vector = array_to_f32_vec(arr)?;
Ok(QueryTableRequestVector {
single_vector: Some(single_vector),
multi_vector: None,
})
} else {
// Handle multi-vector case
let multi_vector: Result<Vec<Vec<f32>>> =
query_vectors.iter().map(array_to_f32_vec).collect();
Ok(QueryTableRequestVector {
single_vector: None,
multi_vector: Some(multi_vector?),
})
}
}
/// Convert an Arrow array to a Vec<f32>.
fn array_to_f32_vec(arr: &Arc<dyn arrow_array::Array>) -> Result<Vec<f32>> {
// Handle FixedSizeList (common for vectors)
if let Some(fsl) = arr
.as_any()
.downcast_ref::<arrow_array::FixedSizeListArray>()
{
let values = fsl.values();
if let Some(f32_arr) = values.as_any().downcast_ref::<arrow_array::Float32Array>() {
return Ok(f32_arr.values().to_vec());
}
}
// Handle direct Float32Array
if let Some(f32_arr) = arr.as_any().downcast_ref::<arrow_array::Float32Array>() {
return Ok(f32_arr.values().to_vec());
}
Err(Error::InvalidInput {
message: "Query vector must be Float32 type".to_string(),
})
}
/// Parse Arrow IPC response from the namespace server.
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
use arrow_ipc::reader::StreamReader;
use std::io::Cursor;
let cursor = Cursor::new(bytes);
let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC response: {}", e),
})?;
// Collect all record batches
let schema = reader.schema();
let batches: Vec<_> = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC batches: {}", e),
})?;
// Create a stream from the batches
let stream = futures::stream::iter(batches.into_iter().map(Ok));
let record_batch_stream =
Box::pin(datafusion_physical_plan::stream::RecordBatchStreamAdapter::new(schema, stream));
Ok(DatasetRecordBatchStream::new(record_batch_stream))
}
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use arrow_array::Float32Array;
use futures::TryStreamExt;
use std::sync::Arc;
use super::*;
use crate::query::QueryExecutionOptions;
#[test]
fn test_convert_to_namespace_query_vector() {
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));
let vq = VectorQueryRequest {
base: QueryRequest {
limit: Some(10),
offset: Some(5),
filter: Some(QueryFilter::Sql("id > 0".to_string())),
select: Select::Columns(vec!["id".to_string()]),
..Default::default()
},
column: Some("vector".to_string()),
// We cast here to satisfy the struct definition
query_vector: vec![query_vector as Arc<dyn Array>],
minimum_nprobes: 20,
distance_type: Some(crate::DistanceType::L2),
..Default::default()
};
let any_query = AnyQuery::VectorQuery(vq);
let ns_request = convert_to_namespace_query(&any_query).unwrap();
assert_eq!(ns_request.k, 10);
assert_eq!(ns_request.offset, Some(5));
assert_eq!(ns_request.filter, Some("id > 0".to_string()));
assert_eq!(
ns_request
.columns
.as_ref()
.and_then(|c| c.column_names.as_ref()),
Some(&vec!["id".to_string()])
);
assert_eq!(ns_request.vector_column, Some("vector".to_string()));
assert_eq!(ns_request.distance_type, Some("l2".to_string()));
// Verify the vector data was extracted correctly
assert!(ns_request.vector.single_vector.is_some());
assert_eq!(
ns_request.vector.single_vector.as_ref().unwrap(),
&vec![1.0, 2.0, 3.0, 4.0]
);
}
#[test]
fn test_convert_to_namespace_query_plain_query() {
let q = QueryRequest {
limit: Some(20),
offset: Some(5),
filter: Some(QueryFilter::Sql("id > 5".to_string())),
select: Select::Columns(vec!["id".to_string()]),
with_row_id: true,
..Default::default()
};
let any_query = AnyQuery::Query(q);
let ns_request = convert_to_namespace_query(&any_query).unwrap();
assert_eq!(ns_request.k, 20);
assert_eq!(ns_request.offset, Some(5));
assert_eq!(ns_request.filter, Some("id > 5".to_string()));
assert_eq!(
ns_request
.columns
.as_ref()
.and_then(|c| c.column_names.as_ref()),
Some(&vec!["id".to_string()])
);
assert_eq!(ns_request.with_row_id, Some(true));
assert_eq!(ns_request.bypass_vector_index, Some(true));
assert!(ns_request.vector_column.is_none());
assert!(ns_request.vector.single_vector.as_ref().unwrap().is_empty());
}
#[tokio::test]
async fn test_execute_query_local_routing() {
use crate::connect;
use crate::table::query::execute_query;
use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
let conn = connect("memory://").execute().await.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
)
.unwrap();
let table = conn
.create_table("test_routing", vec![batch])
.execute()
.await
.unwrap();
let native_table = table.as_native().unwrap();
// Set up a request
let req = QueryRequest {
filter: Some(QueryFilter::Sql("id > 3".to_string())),
..Default::default()
};
let query = AnyQuery::Query(req);
// Action: Call execute_query directly
// This validates that execute_query correctly routes to the local DataFusion engine
// when table.namespace_client is None.
let stream = execute_query(native_table, &query, QueryExecutionOptions::default())
.await
.unwrap();
// Verify results
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
let count: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(count, 2); // 4 and 5
}
#[tokio::test]
async fn test_create_plan_multivector_structure() {
use arrow_array::{Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use crate::table::query::create_plan;
use crate::connect;
let conn = connect("memory://").execute().await.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new(
"vector",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
]));
let batch = RecordBatch::new_empty(schema.clone());
let table = conn
.create_table("test_plan", vec![batch])
.execute()
.await
.unwrap();
let native_table = table.as_native().unwrap();
// This triggers the "create_multi_vector_plan" logic branch
let q1 = Arc::new(Float32Array::from(vec![1.0, 2.0]));
let q2 = Arc::new(Float32Array::from(vec![3.0, 4.0]));
let req = VectorQueryRequest {
column: Some("vector".to_string()),
query_vector: vec![q1, q2],
..Default::default()
};
let query = AnyQuery::VectorQuery(req);
// Create the Plan
let plan = create_plan(native_table, &query, QueryExecutionOptions::default())
.await
.unwrap();
// Formatting the plan lets us inspect its hierarchy
let display = DisplayableExecutionPlan::new(plan.as_ref())
.indent(true)
.to_string();
// We expect a RepartitionExec wrapping a UnionExec
assert!(
display.contains("RepartitionExec"),
"Plan should include Repartitioning"
);
assert!(
display.contains("UnionExec"),
"Plan should include a Union of multiple searches"
);
// We expect the projection to add the 'query_index' column (logic inside create_multi_vector_plan)
assert!(
display.contains("query_index"),
"Plan should add query_index column"
);
}
}
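`create_multi_vector_plan` above runs one search per query vector, prepends a `query_index` literal column to each result set (`ProjectionExec`), and unions them (`UnionExec`). Stripped of DataFusion, the tag-and-union shape can be sketched like this (the `Row` type and helper are illustrative, not part of the actual API):

```rust
#[derive(Debug, PartialEq)]
struct Row {
    query_index: i32, // which query vector produced this hit
    id: u64,
    distance: f32,
}

// Tag each per-query result set with its index, then union into one list,
// mirroring ProjectionExec (add query_index) followed by UnionExec.
fn union_with_query_index(per_query: Vec<Vec<(u64, f32)>>) -> Vec<Row> {
    per_query
        .into_iter()
        .enumerate()
        .flat_map(|(i, hits)| {
            hits.into_iter().map(move |(id, distance)| Row {
                query_index: i as i32,
                id,
                distance,
            })
        })
        .collect()
}

fn main() {
    let merged = union_with_query_index(vec![
        vec![(7, 0.1), (3, 0.4)], // hits for query vector 0
        vec![(9, 0.2)],           // hits for query vector 1
    ]);
    assert_eq!(merged.len(), 3);
    assert_eq!(merged[2].query_index, 1);
    assert!((merged[2].distance - 0.2).abs() < 1e-6);
    println!("ok");
}
```

The real plan additionally repartitions the union down to a single output partition, which this flat list sidesteps.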

View File

@@ -52,11 +52,12 @@ pub(crate) async fn execute_add_columns(
     transforms: NewColumnTransform,
     read_columns: Option<Vec<String>>,
 ) -> Result<AddColumnsResult> {
-    let mut dataset = table.dataset.get_mut().await?;
+    table.dataset.ensure_mutable()?;
+    let mut dataset = (*table.dataset.get().await?).clone();
     dataset.add_columns(transforms, read_columns, None).await?;
-    Ok(AddColumnsResult {
-        version: dataset.version().version,
-    })
+    let version = dataset.version().version;
+    table.dataset.update(dataset);
+    Ok(AddColumnsResult { version })
 }
 /// Internal implementation of the alter columns logic.
@@ -66,11 +67,12 @@ pub(crate) async fn execute_alter_columns(
     table: &NativeTable,
     alterations: &[ColumnAlteration],
 ) -> Result<AlterColumnsResult> {
-    let mut dataset = table.dataset.get_mut().await?;
+    table.dataset.ensure_mutable()?;
+    let mut dataset = (*table.dataset.get().await?).clone();
     dataset.alter_columns(alterations).await?;
-    Ok(AlterColumnsResult {
-        version: dataset.version().version,
-    })
+    let version = dataset.version().version;
+    table.dataset.update(dataset);
+    Ok(AlterColumnsResult { version })
 }
 /// Internal implementation of the drop columns logic.
@@ -80,11 +82,12 @@ pub(crate) async fn execute_drop_columns(
     table: &NativeTable,
     columns: &[&str],
 ) -> Result<DropColumnsResult> {
-    let mut dataset = table.dataset.get_mut().await?;
+    table.dataset.ensure_mutable()?;
+    let mut dataset = (*table.dataset.get().await?).clone();
     dataset.drop_columns(columns).await?;
-    Ok(DropColumnsResult {
-        version: dataset.version().version,
-    })
+    let version = dataset.version().version;
+    table.dataset.update(dataset);
+    Ok(DropColumnsResult { version })
 }
#[cfg(test)]
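The hunks above replace in-place mutation (`get_mut`) with a snapshot, mutate-a-copy, then publish pattern (`ensure_mutable` followed by `update`). A stdlib-only sketch of that copy-on-write shape, assuming a simplified `Versioned` handle (illustrative names, not the actual dataset API):

```rust
use std::sync::{Arc, Mutex};

// Illustrative stand-in for the table's dataset handle.
struct Versioned<T> {
    current: Mutex<Arc<T>>,
}

impl<T: Clone> Versioned<T> {
    fn new(value: T) -> Self {
        Self {
            current: Mutex::new(Arc::new(value)),
        }
    }

    // Readers take a cheap Arc snapshot; they are never affected by writers.
    fn get(&self) -> Arc<T> {
        self.current.lock().unwrap().clone()
    }

    // Writers clone the snapshot, mutate the copy, then publish it.
    fn mutate(&self, f: impl FnOnce(&mut T)) {
        let snapshot = self.get();
        let mut copy = (*snapshot).clone();
        f(&mut copy);
        *self.current.lock().unwrap() = Arc::new(copy);
    }
}

fn main() {
    let versions = Versioned::new(vec![1, 2, 3]);
    let before = versions.get(); // snapshot taken before the write
    versions.mutate(|v| v.push(4));
    // Old snapshots are immutable; only the published value advances.
    assert_eq!(before.as_ref(), &vec![1, 2, 3]);
    assert_eq!(versions.get().as_ref(), &vec![1, 2, 3, 4]);
    println!("ok");
}
```

The payoff is that readers holding an old `Arc` snapshot keep a consistent view while a schema change is prepared and swapped in.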

View File

@@ -78,11 +78,13 @@ pub(crate) async fn execute_update(
     table: &NativeTable,
     update: UpdateBuilder,
 ) -> Result<UpdateResult> {
+    table.dataset.ensure_mutable()?;
     // 1. Snapshot the current dataset
-    let dataset = table.dataset.get().await?.clone();
+    let dataset = table.dataset.get().await?;
     // 2. Initialize the Lance Core builder
-    let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
+    let mut builder = LanceUpdateBuilder::new(dataset);
     // 3. Apply the filter (WHERE clause)
     if let Some(predicate) = update.filter {
@@ -99,10 +101,7 @@
     let res = operation.execute().await?;
     // 6. Update the table's view of the latest version
-    table
-        .dataset
-        .set_latest(res.new_dataset.as_ref().clone())
-        .await;
+    table.dataset.update(res.new_dataset.as_ref().clone());
     Ok(UpdateResult {
         rows_updated: res.rows_updated,

View File

@@ -4,3 +4,14 @@
 pub mod connection;
 pub mod datagen;
 pub mod embeddings;
+
+#[derive(Debug)]
+pub struct TestCustomError;
+
+impl std::fmt::Display for TestCustomError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "TestCustomError occurred")
+    }
+}
+
+impl std::error::Error for TestCustomError {}

View File
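The new file below implements `BackgroundCache`, which serves a cached value while it is fresh, starts a background refresh once the value's age crosses `ttl - refresh_window`, and blocks callers only after full TTL expiry. That three-way freshness split can be sketched in isolation (stdlib only; `classify` and `Freshness` are illustrative names, not part of the file):

```rust
use std::time::{Duration, Instant};

/// Where a cached value sits relative to its TTL.
#[derive(Debug, PartialEq)]
enum Freshness {
    Fresh,         // serve from cache, no fetch
    RefreshWindow, // serve from cache, kick off a background fetch
    Expired,       // block on a fetch
}

// Classify an entry cached at `cached_at`, mirroring the split in
// BackgroundCache (fresh_threshold = ttl - refresh_window).
fn classify(cached_at: Instant, now: Instant, ttl: Duration, refresh_window: Duration) -> Freshness {
    let age = now.duration_since(cached_at);
    if age < ttl - refresh_window {
        Freshness::Fresh
    } else if age < ttl {
        Freshness::RefreshWindow
    } else {
        Freshness::Expired
    }
}

fn main() {
    // Same constants as the file's tests: 30s TTL, 5s refresh window.
    let ttl = Duration::from_secs(30);
    let window = Duration::from_secs(5);
    let t0 = Instant::now();
    assert_eq!(classify(t0, t0 + Duration::from_secs(10), ttl, window), Freshness::Fresh);
    assert_eq!(classify(t0, t0 + Duration::from_secs(27), ttl, window), Freshness::RefreshWindow);
    assert_eq!(classify(t0, t0 + Duration::from_secs(31), ttl, window), Freshness::Expired);
    println!("ok");
}
```

The real cache layers a shared in-flight future and a generation counter on top of this predicate so concurrent callers join one fetch and invalidation wins races against stale background writes.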

@@ -0,0 +1,593 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! A cache that refreshes values in the background before they expire.
//!
//! See [`BackgroundCache`] for details.
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use futures::future::{BoxFuture, Shared};
use futures::FutureExt;
type SharedFut<V, E> = Shared<BoxFuture<'static, Result<V, Arc<E>>>>;
enum State<V, E> {
Empty,
Current(V, clock::Instant),
Refreshing {
previous: Option<(V, clock::Instant)>,
future: SharedFut<V, E>,
},
}
impl<V: Clone, E> State<V, E> {
fn fresh_value(&self, ttl: Duration, refresh_window: Duration) -> Option<V> {
let fresh_threshold = ttl - refresh_window;
match self {
Self::Current(value, cached_at)
| Self::Refreshing {
previous: Some((value, cached_at)),
..
} => {
if clock::now().duration_since(*cached_at) < fresh_threshold {
Some(value.clone())
} else {
None
}
}
_ => None,
}
}
}
struct CacheInner<V, E> {
state: State<V, E>,
/// Incremented on invalidation. Background fetches check this to avoid
/// overwriting with stale data after a concurrent invalidation.
generation: u64,
}
enum Action<V, E> {
Return(V),
Wait(SharedFut<V, E>),
}
/// A cache that refreshes values in the background before they expire.
///
/// The cache has three states:
/// - **Empty**: No cached value. The next [`get()`](Self::get) blocks until a fetch completes.
/// - **Current**: A valid cached value with a timestamp. Returns immediately if fresh.
/// - **Refreshing**: A fetch is in progress. Returns the previous value if still valid,
/// otherwise blocks until the fetch completes.
///
/// When the cached value enters the refresh window (close to TTL expiry),
/// [`get()`](Self::get) starts a background fetch and returns the current value
/// immediately. Multiple concurrent callers share a single in-flight fetch.
pub struct BackgroundCache<V, E> {
inner: Arc<Mutex<CacheInner<V, E>>>,
ttl: Duration,
refresh_window: Duration,
}
impl<V, E> std::fmt::Debug for BackgroundCache<V, E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BackgroundCache")
.field("ttl", &self.ttl)
.field("refresh_window", &self.refresh_window)
.finish_non_exhaustive()
}
}
impl<V, E> Clone for BackgroundCache<V, E> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
ttl: self.ttl,
refresh_window: self.refresh_window,
}
}
}
impl<V, E> BackgroundCache<V, E>
where
V: Clone + Send + Sync + 'static,
E: Send + Sync + 'static,
{
pub fn new(ttl: Duration, refresh_window: Duration) -> Self {
assert!(
refresh_window < ttl,
"refresh_window ({refresh_window:?}) must be less than ttl ({ttl:?})"
);
Self {
inner: Arc::new(Mutex::new(CacheInner {
state: State::Empty,
generation: 0,
})),
ttl,
refresh_window,
}
}
/// Returns the cached value if it's fresh (not in the refresh window).
///
/// This is a cheap synchronous check useful as a fast path before
/// constructing a fetch closure for [`get()`](Self::get).
pub fn try_get(&self) -> Option<V> {
let cache = self.inner.lock().unwrap();
cache.state.fresh_value(self.ttl, self.refresh_window)
}
/// Get the cached value, fetching if needed.
///
/// The closure is called to create the fetch future only when a new fetch
/// is needed. If the cache already has an in-flight fetch, the closure is
/// not called and the caller joins the existing fetch.
pub async fn get<F, Fut>(&self, fetch: F) -> Result<V, Arc<E>>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<V, E>> + Send + 'static,
{
// Fast path: check if cache is fresh
{
let cache = self.inner.lock().unwrap();
if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) {
return Ok(value);
}
}
// Slow path
let mut fetch = Some(fetch);
let action = {
let mut cache = self.inner.lock().unwrap();
self.determine_action(&mut cache, &mut fetch)
};
match action {
Action::Return(value) => Ok(value),
Action::Wait(fut) => fut.await,
}
}
/// Pre-populate the cache with an initial value.
///
/// This avoids a blocking fetch on the first [`get()`](Self::get) call.
pub fn seed(&self, value: V) {
let mut cache = self.inner.lock().unwrap();
cache.state = State::Current(value, clock::now());
}
/// Invalidate the cache. The next [`get()`](Self::get) will start a fresh fetch.
///
/// Any in-flight background fetch from before this call will not update the
/// cache (the generation counter prevents stale writes).
pub fn invalidate(&self) {
let mut cache = self.inner.lock().unwrap();
cache.state = State::Empty;
cache.generation += 1;
}
fn determine_action<F, Fut>(
&self,
cache: &mut CacheInner<V, E>,
fetch: &mut Option<F>,
) -> Action<V, E>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<V, E>> + Send + 'static,
{
match &cache.state {
State::Empty => {
let f = fetch
.take()
.expect("fetch closure required for empty cache");
let shared = self.start_fetch(cache, f, None);
Action::Wait(shared)
}
State::Current(value, cached_at) => {
let elapsed = clock::now().duration_since(*cached_at);
if elapsed < self.ttl - self.refresh_window {
Action::Return(value.clone())
} else if elapsed < self.ttl {
// In refresh window: start background fetch, return current value
let value = value.clone();
let previous = Some((value.clone(), *cached_at));
if let Some(f) = fetch.take() {
// The spawned task inside start_fetch drives the future;
// we don't need to await the returned handle here.
drop(self.start_fetch(cache, f, previous));
}
Action::Return(value)
} else {
// Expired: must wait for fetch
let previous = Some((value.clone(), *cached_at));
let f = fetch
.take()
.expect("fetch closure required for expired cache");
let shared = self.start_fetch(cache, f, previous);
Action::Wait(shared)
}
}
State::Refreshing { previous, future } => {
// If the background fetch already completed (spawned task hasn't
// run yet to update state), transition the state and re-evaluate.
if let Some(result) = future.peek() {
match result {
Ok(value) => {
cache.state = State::Current(value.clone(), clock::now());
}
Err(_) => {
cache.state = match previous.clone() {
Some((v, t)) => State::Current(v, t),
None => State::Empty,
};
}
}
return self.determine_action(cache, fetch);
}
if let Some((value, cached_at)) = previous {
if clock::now().duration_since(*cached_at) < self.ttl {
Action::Return(value.clone())
} else {
Action::Wait(future.clone())
}
} else {
Action::Wait(future.clone())
}
}
}
}
fn start_fetch<F, Fut>(
&self,
cache: &mut CacheInner<V, E>,
fetch: F,
previous: Option<(V, clock::Instant)>,
) -> SharedFut<V, E>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<V, E>> + Send + 'static,
{
let generation = cache.generation;
let shared = async move { (fetch)().await.map_err(Arc::new) }
.boxed()
.shared();
// Spawn task to eagerly drive the future and update state on completion
let inner = self.inner.clone();
let fut_for_spawn = shared.clone();
tokio::spawn(async move {
let result = fut_for_spawn.await;
let mut cache = inner.lock().unwrap();
// Only update if no invalidation has happened since we started
if cache.generation != generation {
return;
}
match result {
Ok(value) => {
cache.state = State::Current(value, clock::now());
}
Err(_) => {
let prev = match &cache.state {
State::Refreshing { previous, .. } => previous.clone(),
_ => None,
};
cache.state = match prev {
Some((v, t)) => State::Current(v, t),
None => State::Empty,
};
}
}
});
cache.state = State::Refreshing {
previous,
future: shared.clone(),
};
shared
}
}
#[cfg(test)]
pub mod clock {
use std::cell::Cell;
use std::time::Duration;
// Re-export Instant so callers use the same type
pub use std::time::Instant;
thread_local! {
static MOCK_NOW: Cell<Option<Instant>> = const { Cell::new(None) };
}
pub fn now() -> Instant {
MOCK_NOW.with(|mock| mock.get().unwrap_or_else(Instant::now))
}
pub fn advance_by(duration: Duration) {
MOCK_NOW.with(|mock| {
let current = mock.get().unwrap_or_else(Instant::now);
mock.set(Some(current + duration));
});
}
#[allow(dead_code)]
pub fn clear_mock() {
MOCK_NOW.with(|mock| mock.set(None));
}
}
#[cfg(not(test))]
mod clock {
// Re-export Instant so callers use the same type
pub use std::time::Instant;
pub fn now() -> Instant {
Instant::now()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Debug)]
struct TestError(String);
impl std::fmt::Display for TestError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
const TEST_TTL: Duration = Duration::from_secs(30);
const TEST_REFRESH_WINDOW: Duration = Duration::from_secs(5);
fn new_cache() -> BackgroundCache<String, TestError> {
BackgroundCache::new(TEST_TTL, TEST_REFRESH_WINDOW)
}
fn ok_fetcher(
counter: Arc<AtomicUsize>,
value: &str,
) -> impl FnOnce() -> BoxFuture<'static, Result<String, TestError>> + Send + 'static {
let value = value.to_string();
move || {
counter.fetch_add(1, Ordering::SeqCst);
async move { Ok(value) }.boxed()
}
}
fn err_fetcher(
counter: Arc<AtomicUsize>,
msg: &str,
) -> impl FnOnce() -> BoxFuture<'static, Result<String, TestError>> + Send + 'static {
let msg = msg.to_string();
move || {
counter.fetch_add(1, Ordering::SeqCst);
async move { Err(TestError(msg)) }.boxed()
}
}
#[tokio::test]
async fn test_basic_caching() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
let v1 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
assert_eq!(v1, "hello");
assert_eq!(count.load(Ordering::SeqCst), 1);
// Second call triggers peek transition to Current, returns cached
let v2 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
assert_eq!(v2, "hello");
assert_eq!(count.load(Ordering::SeqCst), 1);
// Third call still cached
let v3 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
assert_eq!(v3, "hello");
assert_eq!(count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_try_get_returns_none_when_empty() {
let cache: BackgroundCache<String, TestError> = new_cache();
assert!(cache.try_get().is_none());
}
#[tokio::test]
async fn test_try_get_returns_value_when_fresh() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
// Peek transition
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
assert_eq!(cache.try_get().unwrap(), "hello");
}
#[tokio::test]
async fn test_try_get_returns_none_in_refresh_window() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); // peek
        clock::advance_by(Duration::from_secs(26));
        assert!(cache.try_get().is_none());
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek
        assert_eq!(count.load(Ordering::SeqCst), 1);
        clock::advance_by(Duration::from_secs(31));
        let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
        assert_eq!(v, "v2");
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_invalidate_forces_refetch() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek
        assert_eq!(count.load(Ordering::SeqCst), 1);
        cache.invalidate();
        let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
        assert_eq!(v, "v2");
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_concurrent_get_single_fetch() {
        let cache = Arc::new(new_cache());
        let count = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..10 {
            let cache = cache.clone();
            let count = count.clone();
            handles.push(tokio::spawn(async move {
                cache.get(ok_fetcher(count, "hello")).await.unwrap()
            }));
        }
        let results: Vec<String> = futures::future::try_join_all(handles).await.unwrap();
        for r in &results {
            assert_eq!(r, "hello");
        }
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_background_refresh_in_window() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        // Populate and transition to Current
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek
        assert_eq!(count.load(Ordering::SeqCst), 1);
        // Move into refresh window
        clock::advance_by(Duration::from_secs(26));
        // Returns cached value and starts background fetch
        let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
        assert_eq!(v, "v1"); // Still old value
        assert_eq!(count.load(Ordering::SeqCst), 1); // bg task hasn't run yet
        // Advance past TTL to force waiting on the shared future
        clock::advance_by(Duration::from_secs(30));
        let v = cache.get(ok_fetcher(count.clone(), "v3")).await.unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 2);
        assert_eq!(v, "v2"); // Got the bg refresh result
    }

    #[tokio::test]
    async fn test_no_duplicate_background_refreshes() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        // Populate and transition to Current
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek
        assert_eq!(count.load(Ordering::SeqCst), 1);
        // Move into refresh window
        clock::advance_by(Duration::from_secs(26));
        // Multiple calls should all return cached, only one bg fetch
        for _ in 0..5 {
            let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
            assert_eq!(v, "v1");
        }
        // Drive the shared future to completion
        clock::advance_by(Duration::from_secs(30));
        cache.get(ok_fetcher(count.clone(), "v3")).await.unwrap();
        // Only 1 additional fetch (the background refresh)
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_background_refresh_error_preserves_cache() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        // Populate and transition to Current
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek
        assert_eq!(count.load(Ordering::SeqCst), 1);
        // Move into refresh window
        clock::advance_by(Duration::from_secs(26));
        // Start bg refresh that will fail, returns cached value
        let v = cache.get(err_fetcher(count.clone(), "fail")).await.unwrap();
        assert_eq!(v, "v1");
        // Still in refresh window, previous is valid
        let v = cache.get(err_fetcher(count.clone(), "fail")).await.unwrap();
        assert_eq!(v, "v1");
        // Advance past TTL to drive the failed future
        clock::advance_by(Duration::from_secs(30));
        // The peek error path restores previous, but it's expired,
        // so a new fetch is needed. This one also fails.
        let result = cache.get(err_fetcher(count.clone(), "fail again")).await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_invalidation_during_fetch_prevents_stale_update() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        // Populate and transition to Current
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek
        // Move into refresh window to start background fetch
        clock::advance_by(Duration::from_secs(26));
        cache.get(ok_fetcher(count.clone(), "stale")).await.unwrap();
        // Invalidate before bg task completes
        cache.invalidate();
        // Advance past TTL
        clock::advance_by(Duration::from_secs(30));
        // Should get fresh data, not the stale background result
        let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap();
        assert_eq!(v, "fresh");
    }
}
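The tests above drive a cached entry through three states: fresh (served directly), inside a refresh window (served from cache while a single background refetch runs), and past its TTL (callers must wait for a new fetch). Judging by the `clock::advance_by` calls, the refresh window appears to open around 25 s and the TTL at 30 s. A minimal sketch of just that staleness decision, with the names `Freshness` and `classify` being hypothetical illustrations rather than the crate's API:

```rust
use std::time::Duration;

/// Where a cached entry sits relative to its TTL.
#[derive(Debug, PartialEq)]
enum Freshness {
    /// Younger than the refresh window: serve from cache.
    Fresh,
    /// Inside the refresh window: serve from cache, refetch in background.
    NeedsBackgroundRefresh,
    /// At or past the TTL: callers must wait for a new fetch.
    Expired,
}

/// Classify an entry by its age; `refresh_after` must be < `ttl`.
fn classify(age: Duration, ttl: Duration, refresh_after: Duration) -> Freshness {
    if age >= ttl {
        Freshness::Expired
    } else if age >= refresh_after {
        Freshness::NeedsBackgroundRefresh
    } else {
        Freshness::Fresh
    }
}

fn main() {
    // Assumed constants: 30 s TTL, refresh window opening at 25 s.
    let ttl = Duration::from_secs(30);
    let refresh_after = Duration::from_secs(25);
    // Mirrors the advance_by(26) steps: cached value is served, bg fetch starts.
    assert_eq!(
        classify(Duration::from_secs(26), ttl, refresh_after),
        Freshness::NeedsBackgroundRefresh
    );
    // Mirrors the advance_by(31)/past-TTL steps: a blocking refetch is required.
    assert_eq!(classify(Duration::from_secs(31), ttl, refresh_after), Freshness::Expired);
    println!("ok");
}
```

The deduplication the tests assert (one fetch across ten concurrent `get`s, one background refresh across five calls in the window) would sit on top of this classification, e.g. via a shared future guarded by a mutex.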


@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub(crate) mod background_cache;
use std::sync::Arc;
use arrow_array::RecordBatch;