Compare commits

..

14 Commits

Author SHA1 Message Date
Dan Tasse
97754f5123 fix: change _client reference to _conn (#3188)
This code previously referenced `self._client`, which does not exist.
This change makes it correctly call `self._conn.close()`
2026-03-31 13:29:17 -07:00
Pratik Dey
7b1c063848 feat(python): add type-safe expression builder API (#3150)
Introduces col(), lit(), func(), and Expr class as alternatives to raw
SQL strings in .where() and .select(). Expressions are backed by
DataFusion's Expr AST and serialized to SQL for remote table compat.

Resolves: 
- https://github.com/lancedb/lancedb/issues/3044 (python api's)
- https://github.com/lancedb/lancedb/issues/3043 (support for filter)
- https://github.com/lancedb/lancedb/issues/3045 (support for
projection)

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-31 11:32:49 -07:00
Will Jones
c7f189f27b chore: upgrade lance to stable 4.0.0 (#3207)
Bumps all lance-* workspace dependencies from `4.0.0-rc.3` (git source)
to the stable `4.0.0` release on crates.io, removing the `git`/`tag`
overrides.

No code changes were required — compiles and passes clippy cleanly.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-30 17:05:45 -07:00
yaommen
a0a2942ad5 fix: respect max_batch_length for Rust vector and hybrid queries (#3172)
Fixes #1540

I could not reproduce this on current `main` from Python, but I could
still reproduce it from the Rust SDK.

Python no longer reproduces because the current Python vector/hybrid
query paths re-chunk results into a `pyarrow.Table` before returning
batches. Rust still reproduced because `max_batch_length` was passed
into planning/scanning, but vector search could still emit larger
`RecordBatch`es later in execution (for example after KNN / TopK), so it
was not enforced on the final Rust output stream.

This PR enforces `max_batch_length` on the final Rust query output
stream and adds Rust regression coverage.

Before the fix, the Rust repro produced:
`num_batches=2, max_batch=8192, min_batch=1808, all_le_100=false`

After the fix, the same repro produces batches `<= 100`.

## Runnable Rust repro

Before this fix, current `main` could still return batches like `[8192,
1808]` here even with `max_batch_length = 100`:

```rust
use std::sync::Arc;

use arrow_array::{
    types::Float32Type, FixedSizeListArray, RecordBatch, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, QueryExecutionOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tmp = tempfile::tempdir()?;
    let uri = tmp.path().to_str().unwrap();

    let rows = 10_000;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            false,
        ),
    ]));

    let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
    let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
        (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
        4,
    );
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
    let reader: Box<dyn RecordBatchReader + Send> = Box::new(
        arrow_array::RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema),
    );

    let db = lancedb::connect(uri).execute().await?;
    let table = db.create_table("test", reader).execute().await?;

    let mut opts = QueryExecutionOptions::default();
    opts.max_batch_length = 100;

    let mut stream = table
        .query()
        .nearest_to(vec![0.0, 1.0, 2.0, 3.0])?
        .limit(rows)
        .execute_with_options(opts)
        .await?;

    let mut sizes = Vec::new();
    while let Some(batch) = stream.try_next().await? {
        sizes.push(batch.num_rows());
    }

    println!("{sizes:?}");
    Ok(())
}
```

Signed-off-by: yaommen <myanstu@163.com>
2026-03-30 15:43:58 -07:00
Will Jones
e3d53dd185 fix(python): skip test_url_retrieve_downloads_image when PIL not installed (#3208)
The test added in #3190 unconditionally imports `PIL`, which is an
optional dependency. This causes CI failures in environments where
Pillow isn't installed (`ModuleNotFoundError: No module named 'PIL'`).

Use `pytest.importorskip` to skip gracefully when Pillow is unavailable.

Fixes CI failure on main.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-30 14:48:49 -07:00
Will Jones
66804e99fc fix(python): use correct exception types in namespace tests (#3206)
## Summary
- Namespace tests expected `RuntimeError` for table-not-found and
namespace-not-empty cases, but `lance_namespace` raises
`TableNotFoundError` and `NamespaceNotEmptyError` which inherit from
`Exception`, not `RuntimeError`.
- Updated `pytest.raises` to use the correct exception types.

## Test plan
- [x] CI passes on `test_namespace.py`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 12:55:54 -07:00
lennylxx
9f85d4c639 fix(embeddings): add missing urllib.request import in url_retrieve (#3190)
url_retrieve() calls urllib.request.urlopen() but only urllib.error was
imported, causing AttributeError for any HTTP URL input. This affects
open-clip, siglip, and jinaai embedding functions when processing image
URLs.

The bug has existed since the embeddings API refactor (#580) but was
masked because most users pass local file paths or bytes rather than
HTTP URLs.
2026-03-30 12:03:44 -07:00
Vedant Madane
1ba19d728e feat(node): support Float16, Float64, and Uint8 vector queries (#3193)
Fixes #2716

## Summary

Add support for querying with Float16Array, Float64Array, and Uint8Array
vectors in the Node.js SDK, eliminating precision loss from the previous
\Float32Array.from()\ conversion.

## Implementation

Follows @wjones127's [5-step
plan](https://github.com/lancedb/lancedb/issues/2716#issuecomment-3447750543):

### Rust (\
odejs/src/query.rs\)

1. \ytes_to_arrow_array(data: Uint8Array, dtype: String)\ helper that:
   - Creates an Arrow \Buffer\ from the raw bytes
   - Wraps it in a typed \ScalarBuffer<T>\ based on the dtype enum
   - Constructs a \PrimitiveArray\ and returns \Arc<dyn Array>\
2. \
earest_to_raw(data, dtype)\ and \dd_query_vector_raw(data, dtype)\ NAPI
methods that pass the type-erased array to the core \
earest_to\/\dd_query_vector\ which already accept \impl
IntoQueryVector\ for \Arc<dyn Array>\

### TypeScript (\
odejs/lancedb/query.ts\, \rrow.ts\)

3. Extended \IntoVector\ type to include \Uint8Array\ (and
\Float16Array\ via runtime check for Node 22+)
4. \xtractVectorBuffer()\ helper detects non-Float32 typed arrays and
extracts their underlying byte buffer + dtype string
5. \
earestTo()\ and \ddQueryVector()\ route through the raw NAPI path when
the input is Float16/Float64/Uint8

### Backward compatibility

Existing \Float32Array\ and \
umber[]\ inputs are unchanged -- they still use the original \
earest_to(Float32Array)\ NAPI method. The new raw path is only used when
a non-Float32 typed array is detected.

## Usage

\\\	ypescript
// Float16Array (Node 22+) -- no precision loss
const f16vec = new Float16Array([0.1, 0.2, 0.3]);
const results = await
table.query().nearestTo(f16vec).limit(10).toArray();

// Float64Array -- no precision loss
const f64vec = new Float64Array([0.1, 0.2, 0.3]);
const results = await
table.query().nearestTo(f64vec).limit(10).toArray();

// Uint8Array (binary embeddings)
const u8vec = new Uint8Array([1, 0, 1, 1, 0]);
const results = await
table.query().nearestTo(u8vec).limit(10).toArray();

// Existing usage unchanged
const results = await table.query().nearestTo([0.1, 0.2,
0.3]).limit(10).toArray();
\\\

## Note on dependencies

The Rust side uses \rrow_array\, \rrow_buffer\, and \half\ crates.
These should already be in the dependency tree via \lancedb\ core, but
\Cargo.toml\ may need explicit entries for \half\ and the arrow
sub-crates in the nodejs workspace.

---------

Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Co-authored-by: Will Jones <willjones127@gmail.com>
2026-03-30 11:15:35 -07:00
lif
4c44587af0 fix: table.add(mode='overwrite') infers vector column types (#3184)
Fixes #3183

## Summary

When `table.add(mode='overwrite')` is called, PyArrow infers input data
types (e.g. `list<double>`) which differ from the original table schema
(e.g. `fixed_size_list<float32>`). Previously, overwrite mode bypassed
`cast_to_table_schema()` entirely, so the inferred types replaced the
original schema, breaking vector search.

This fix builds a merged target schema for overwrite: columns present in
the existing table schema keep their original types, while columns
unique to the input pass through as-is. This way
`cast_to_table_schema()` is applied unconditionally, preserving vector
column types without blocking schema evolution.

## Changes

- `rust/lancedb/src/table/add_data.rs`: For overwrite mode, construct a
target schema by matching input columns against the existing table
schema, then cast. Non-overwrite (append) path is unchanged.
- Added `test_add_overwrite_preserves_vector_type` test that creates a
table with `fixed_size_list<float32>`, overwrites with `list<double>`
input, and asserts the original type is preserved.

## Test Plan

- `cargo test --features remote -p lancedb -- test_add_overwrite` — all
4 overwrite tests pass
- Full suite: 454 passed, 2 failed (pre-existing `remote::retry` flakes
unrelated to this change)

---------

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-03-30 10:57:33 -07:00
lennylxx
1d1cafb59c fix(python): don't assign dict.update() return value in _sanitize_data (#3198)
dict.update() mutates in place and returns None. Assigning its result
caused with_metadata(None) to strip all schema metadata when embedding
metadata was merged during create_table with embedding_functions.
2026-03-30 10:15:45 -07:00
aikido-autofix[bot]
4714598155 ci: mitigate template injection attack in build_linux_wheel (#3195)
This patch mitigates template injection vulnerabilities in GitHub
Workflows by replacing direct references with an environment variable.

Aikido used AI to generate this PR.

High confidence: Aikido has a robust set of benchmarks for similar
fixes, and they are proven to be effective.

Co-authored-by: aikido-autofix[bot] <119856028+aikido-autofix[bot]@users.noreply.github.com>
2026-03-30 09:29:24 -07:00
lennylxx
74f457a0f2 fix(rust): handle Mutex lock poisoning gracefully across codebase (#3196)
Replace ~30 production `lock().unwrap()` calls that would cascade-panic
on a poisoned Mutex. Functions returning `Result` now propagate the
poison as an error via `?` (leveraging the existing `From<PoisonError>`
impl). Functions without a `Result` return recover via
`unwrap_or_else(|e| e.into_inner())`, which is safe because the guarded
data (counters, caches, RNG state) remains logically valid after a
panic.
2026-03-30 09:25:18 -07:00
Dan Tasse
cca6a7c989 fix: raise instead of return ValueError (#3189)
These couple of cases used to return ValueError; should raise it
instead.
2026-03-25 18:49:29 -07:00
Lance Release
ad96489114 Bump version: 0.27.2-beta.0 → 0.27.2-beta.1 2026-03-25 16:22:09 +00:00
55 changed files with 2103 additions and 198 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.27.2-beta.0" current_version = "0.27.2-beta.1"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -23,8 +23,10 @@ runs:
steps: steps:
- name: CONFIRM ARM BUILD - name: CONFIRM ARM BUILD
shell: bash shell: bash
env:
ARM_BUILD: ${{ inputs.arm-build }}
run: | run: |
echo "ARM BUILD: ${{ inputs.arm-build }}" echo "ARM BUILD: $ARM_BUILD"
- name: Build x86_64 Manylinux wheel - name: Build x86_64 Manylinux wheel
if: ${{ inputs.arm-build == 'false' }} if: ${{ inputs.arm-build == 'false' }}
uses: PyO3/maturin-action@v1 uses: PyO3/maturin-action@v1

125
Cargo.lock generated
View File

@@ -108,7 +108,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [ dependencies = [
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -119,7 +119,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"once_cell_polyfill", "once_cell_polyfill",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -2682,7 +2682,7 @@ dependencies = [
"libc", "libc",
"option-ext", "option-ext",
"redox_users", "redox_users",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -2876,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -3072,8 +3072,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]] [[package]]
name = "fsst" name = "fsst"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2195cc7f87e84bd695586137de99605e7e9579b26ec5e01b82960ddb4d0922f2"
dependencies = [ dependencies = [
"arrow-array", "arrow-array",
"rand 0.9.2", "rand 0.9.2",
@@ -3736,7 +3737,7 @@ dependencies = [
"libc", "libc",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"socket2 0.5.10", "socket2 0.6.3",
"system-configuration", "system-configuration",
"tokio", "tokio",
"tower-service", "tower-service",
@@ -4037,7 +4038,7 @@ dependencies = [
"portable-atomic", "portable-atomic",
"portable-atomic-util", "portable-atomic-util",
"serde_core", "serde_core",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -4123,8 +4124,9 @@ dependencies = [
[[package]] [[package]]
name = "lance" name = "lance"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efe6c3ddd79cdfd2b7e1c23cafae52806906bc40fbd97de9e8cf2f8c7a75fc04"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-arith", "arrow-arith",
@@ -4190,8 +4192,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-arrow" name = "lance-arrow"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d9f5d95bdda2a2b790f1fb8028b5b6dcf661abeb3133a8bca0f3d24b054af87"
dependencies = [ dependencies = [
"arrow-array", "arrow-array",
"arrow-buffer", "arrow-buffer",
@@ -4211,8 +4214,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-bitpacking" name = "lance-bitpacking"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f827d6ab9f8f337a9509d5ad66a12f3314db8713868260521c344ef6135eb4e4"
dependencies = [ dependencies = [
"arrayref", "arrayref",
"paste", "paste",
@@ -4221,8 +4225,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-core" name = "lance-core"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f1e25df6a79bf72ee6bcde0851f19b1cd36c5848c1b7db83340882d3c9fdecb"
dependencies = [ dependencies = [
"arrow-array", "arrow-array",
"arrow-buffer", "arrow-buffer",
@@ -4259,8 +4264,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-datafusion" name = "lance-datafusion"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93146de8ae720cb90edef81c2f2d0a1b065fc2f23ecff2419546f389b0fa70a4"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-array", "arrow-array",
@@ -4290,8 +4296,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-datagen" name = "lance-datagen"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccec8ce4d8e0a87a99c431dab2364398029f2ffb649c1a693c60c79e05ed30dd"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-array", "arrow-array",
@@ -4309,8 +4316,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-encoding" name = "lance-encoding"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c1aec0bbbac6bce829bc10f1ba066258126100596c375fb71908ecf11c2c2a5"
dependencies = [ dependencies = [
"arrow-arith", "arrow-arith",
"arrow-array", "arrow-array",
@@ -4347,8 +4355,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-file" name = "lance-file"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14a8c548804f5b17486dc2d3282356ed1957095a852780283bc401fdd69e9075"
dependencies = [ dependencies = [
"arrow-arith", "arrow-arith",
"arrow-array", "arrow-array",
@@ -4380,8 +4389,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-index" name = "lance-index"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da212f0090ea59f79ac3686660f596520c167fe1cb5f408900cf71d215f0e03"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-arith", "arrow-arith",
@@ -4445,8 +4455,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-io" name = "lance-io"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d958eb4b56f03bbe0f5f85eb2b4e9657882812297b6f711f201ffc995f259f"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-arith", "arrow-arith",
@@ -4487,8 +4498,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-linalg" name = "lance-linalg"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0285b70da35def7ed95e150fae1d5308089554e1290470403ed3c50cb235bc5e"
dependencies = [ dependencies = [
"arrow-array", "arrow-array",
"arrow-buffer", "arrow-buffer",
@@ -4504,8 +4516,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-namespace" name = "lance-namespace"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f78e2a828b654e062a495462c6e3eb4fcf0e7e907d761b8f217fc09ccd3ceac"
dependencies = [ dependencies = [
"arrow", "arrow",
"async-trait", "async-trait",
@@ -4518,8 +4531,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-namespace-impls" name = "lance-namespace-impls"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2392314f3da38f00d166295e44244208a65ccfc256e274fa8631849fc3f4d94"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-ipc", "arrow-ipc",
@@ -4563,8 +4577,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-table" name = "lance-table"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3df9c4adca3eb2074b3850432a9fb34248a3d90c3d6427d158b13ff9355664ee"
dependencies = [ dependencies = [
"arrow", "arrow",
"arrow-array", "arrow-array",
@@ -4603,8 +4618,9 @@ dependencies = [
[[package]] [[package]]
name = "lance-testing" name = "lance-testing"
version = "4.0.0-rc.3" version = "4.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ed7119bdd6983718387b4ac44af873a165262ca94f181b104cd6f97912eb3bf"
dependencies = [ dependencies = [
"arrow-array", "arrow-array",
"arrow-schema", "arrow-schema",
@@ -4615,7 +4631,7 @@ dependencies = [
[[package]] [[package]]
name = "lancedb" name = "lancedb"
version = "0.27.2-beta.0" version = "0.27.2-beta.1"
dependencies = [ dependencies = [
"ahash", "ahash",
"anyhow", "anyhow",
@@ -4697,9 +4713,10 @@ dependencies = [
[[package]] [[package]]
name = "lancedb-nodejs" name = "lancedb-nodejs"
version = "0.27.2-beta.0" version = "0.27.2-beta.1"
dependencies = [ dependencies = [
"arrow-array", "arrow-array",
"arrow-buffer",
"arrow-ipc", "arrow-ipc",
"arrow-schema", "arrow-schema",
"async-trait", "async-trait",
@@ -4707,6 +4724,7 @@ dependencies = [
"aws-lc-sys", "aws-lc-sys",
"env_logger", "env_logger",
"futures", "futures",
"half",
"lancedb", "lancedb",
"log", "log",
"lzma-sys", "lzma-sys",
@@ -4717,7 +4735,7 @@ dependencies = [
[[package]] [[package]]
name = "lancedb-python" name = "lancedb-python"
version = "0.30.2-beta.0" version = "0.30.2-beta.1"
dependencies = [ dependencies = [
"arrow", "arrow",
"async-trait", "async-trait",
@@ -5305,7 +5323,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [ dependencies = [
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -6284,7 +6302,7 @@ version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [ dependencies = [
"heck 0.4.1", "heck 0.5.0",
"itertools 0.14.0", "itertools 0.14.0",
"log", "log",
"multimap", "multimap",
@@ -6471,7 +6489,7 @@ dependencies = [
"quinn-udp", "quinn-udp",
"rustc-hash", "rustc-hash",
"rustls 0.23.37", "rustls 0.23.37",
"socket2 0.5.10", "socket2 0.6.3",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tracing", "tracing",
@@ -6508,9 +6526,9 @@ dependencies = [
"cfg_aliases", "cfg_aliases",
"libc", "libc",
"once_cell", "once_cell",
"socket2 0.5.10", "socket2 0.6.3",
"tracing", "tracing",
"windows-sys 0.52.0", "windows-sys 0.60.2",
] ]
[[package]] [[package]]
@@ -7039,7 +7057,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.4.15", "linux-raw-sys 0.4.15",
"windows-sys 0.52.0", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -7052,7 +7070,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.12.1", "linux-raw-sys 0.12.1",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -7572,7 +7590,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451" checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
dependencies = [ dependencies = [
"heck 0.4.1", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.117", "syn 2.0.117",
@@ -7584,7 +7602,7 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40" checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
dependencies = [ dependencies = [
"heck 0.4.1", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.117", "syn 2.0.117",
@@ -7607,7 +7625,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.60.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -7711,7 +7729,6 @@ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"psm", "psm",
"windows-sys 0.52.0",
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
@@ -8072,7 +8089,7 @@ dependencies = [
"getrandom 0.4.2", "getrandom 0.4.2",
"once_cell", "once_cell",
"rustix 1.1.4", "rustix 1.1.4",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -8877,7 +8894,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]

View File

@@ -15,20 +15,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0" rust-version = "1.91.0"
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance = { version = "=4.0.0", default-features = false }
lance-core = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-core = { version = "=4.0.0" }
lance-datagen = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-datagen = { version = "=4.0.0" }
lance-file = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-file = { version = "=4.0.0" }
lance-io = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-io = { version = "=4.0.0", default-features = false }
lance-index = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-index = { version = "=4.0.0" }
lance-linalg = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-linalg = { version = "=4.0.0" }
lance-namespace = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-namespace = { version = "=4.0.0" }
lance-namespace-impls = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-namespace-impls = { version = "=4.0.0", default-features = false }
lance-table = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-table = { version = "=4.0.0" }
lance-testing = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-testing = { version = "=4.0.0" }
lance-datafusion = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-datafusion = { version = "=4.0.0" }
lance-encoding = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-encoding = { version = "=4.0.0" }
lance-arrow = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" } lance-arrow = { version = "=4.0.0" }
ahash = "0.8" ahash = "0.8"
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false } arrow = { version = "57.2", optional = false }

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency> <dependency>
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId> <artifactId>lancedb-core</artifactId>
<version>0.27.2-beta.0</version> <version>0.27.2-beta.1</version>
</dependency> </dependency>
``` ```

View File

@@ -52,7 +52,7 @@ new EmbeddingFunction<T, M>(): EmbeddingFunction<T, M>
### computeQueryEmbeddings() ### computeQueryEmbeddings()
```ts ```ts
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array> computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
``` ```
Compute the embeddings for a single query Compute the embeddings for a single query
@@ -63,7 +63,7 @@ Compute the embeddings for a single query
#### Returns #### Returns
`Promise`&lt;`number`[] \| `Float32Array` \| `Float64Array`&gt; `Promise`&lt;`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`&gt;
*** ***

View File

@@ -37,7 +37,7 @@ new TextEmbeddingFunction<M>(): TextEmbeddingFunction<M>
### computeQueryEmbeddings() ### computeQueryEmbeddings()
```ts ```ts
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array> computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
``` ```
Compute the embeddings for a single query Compute the embeddings for a single query
@@ -48,7 +48,7 @@ Compute the embeddings for a single query
#### Returns #### Returns
`Promise`&lt;`number`[] \| `Float32Array` \| `Float64Array`&gt; `Promise`&lt;`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`&gt;
#### Overrides #### Overrides

View File

@@ -7,5 +7,10 @@
# Type Alias: IntoVector # Type Alias: IntoVector
```ts ```ts
type IntoVector: Float32Array | Float64Array | number[] | Promise<Float32Array | Float64Array | number[]>; type IntoVector:
| Float32Array
| Float64Array
| Uint8Array
| number[]
| Promise<Float32Array | Float64Array | Uint8Array | number[]>;
``` ```

View File

@@ -36,6 +36,20 @@ is also an [asynchronous API client](#connections-asynchronous).
::: lancedb.table.Tags ::: lancedb.table.Tags
## Expressions
Type-safe expression builder for filters and projections. Use these instead
of raw SQL strings with [where][lancedb.query.LanceQueryBuilder.where] and
[select][lancedb.query.LanceQueryBuilder.select].
::: lancedb.expr.Expr
::: lancedb.expr.col
::: lancedb.expr.lit
::: lancedb.expr.func
## Querying (Synchronous) ## Querying (Synchronous)
::: lancedb.query.Query ::: lancedb.query.Query

View File

@@ -8,7 +8,7 @@
<parent> <parent>
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.27.2-beta.0</version> <version>0.27.2-beta.1</version>
<relativePath>../pom.xml</relativePath> <relativePath>../pom.xml</relativePath>
</parent> </parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.27.2-beta.0</version> <version>0.27.2-beta.1</version>
<packaging>pom</packaging> <packaging>pom</packaging>
<name>${project.artifactId}</name> <name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description> <description>LanceDB Java SDK Parent POM</description>

View File

@@ -1,7 +1,7 @@
[package] [package]
name = "lancedb-nodejs" name = "lancedb-nodejs"
edition.workspace = true edition.workspace = true
version = "0.27.2-beta.0" version = "0.27.2-beta.1"
license.workspace = true license.workspace = true
description.workspace = true description.workspace = true
repository.workspace = true repository.workspace = true
@@ -15,6 +15,8 @@ crate-type = ["cdylib"]
async-trait.workspace = true async-trait.workspace = true
arrow-ipc.workspace = true arrow-ipc.workspace = true
arrow-array.workspace = true arrow-array.workspace = true
arrow-buffer = "57.2"
half.workspace = true
arrow-schema.workspace = true arrow-schema.workspace = true
env_logger.workspace = true env_logger.workspace = true
futures.workspace = true futures.workspace = true

View File

@@ -0,0 +1,110 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as tmp from "tmp";
import { type Table, connect } from "../lancedb";
import {
Field,
FixedSizeList,
Float32,
Int64,
Schema,
makeArrowTable,
} from "../lancedb/arrow";
describe("Vector query with different typed arrays", () => {
let tmpDir: tmp.DirResult;
afterEach(() => {
tmpDir?.removeCallback();
});
async function createFloat32Table(): Promise<Table> {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const db = await connect(tmpDir.name);
const schema = new Schema([
new Field("id", new Int64(), true),
new Field(
"vec",
new FixedSizeList(2, new Field("item", new Float32())),
true,
),
]);
const data = makeArrowTable(
[
{ id: 1n, vec: [1.0, 0.0] },
{ id: 2n, vec: [0.0, 1.0] },
{ id: 3n, vec: [1.0, 1.0] },
],
{ schema },
);
return db.createTable("test_f32", data);
}
it("should search with Float32Array (baseline)", async () => {
const table = await createFloat32Table();
const results = await table
.query()
.nearestTo(new Float32Array([1.0, 0.0]))
.limit(1)
.toArray();
expect(results.length).toBe(1);
expect(Number(results[0].id)).toBe(1);
});
it("should search with number[] (backward compat)", async () => {
const table = await createFloat32Table();
const results = await table
.query()
.nearestTo([1.0, 0.0])
.limit(1)
.toArray();
expect(results.length).toBe(1);
expect(Number(results[0].id)).toBe(1);
});
it("should search with Float64Array via raw path", async () => {
const table = await createFloat32Table();
const results = await table
.query()
.nearestTo(new Float64Array([1.0, 0.0]))
.limit(1)
.toArray();
expect(results.length).toBe(1);
expect(Number(results[0].id)).toBe(1);
});
it("should add multiple query vectors with Float64Array", async () => {
const table = await createFloat32Table();
const results = await table
.query()
.nearestTo(new Float64Array([1.0, 0.0]))
.addQueryVector(new Float64Array([0.0, 1.0]))
.limit(2)
.toArray();
expect(results.length).toBeGreaterThanOrEqual(2);
});
// Float16Array is only available in Node 22+; not in TypeScript's standard lib yet
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
.Float16Array as (new (values: number[]) => unknown) | undefined;
const hasFloat16 = float16ArrayCtor !== undefined;
const f16it = hasFloat16 ? it : it.skip;
f16it("should search with Float16Array via raw path", async () => {
const table = await createFloat32Table();
const results = await table
.query()
.nearestTo(new float16ArrayCtor!([1.0, 0.0]) as Float32Array)
.limit(1)
.toArray();
expect(results.length).toBe(1);
expect(Number(results[0].id)).toBe(1);
});
});

View File

@@ -117,8 +117,9 @@ export type TableLike =
export type IntoVector = export type IntoVector =
| Float32Array | Float32Array
| Float64Array | Float64Array
| Uint8Array
| number[] | number[]
| Promise<Float32Array | Float64Array | number[]>; | Promise<Float32Array | Float64Array | Uint8Array | number[]>;
export type MultiVector = IntoVector[]; export type MultiVector = IntoVector[];
@@ -126,14 +127,48 @@ export function isMultiVector(value: unknown): value is MultiVector {
return Array.isArray(value) && isIntoVector(value[0]); return Array.isArray(value) && isIntoVector(value[0]);
} }
// Float16Array is not in TypeScript's standard lib yet; access dynamically
type Float16ArrayCtor = new (
...args: unknown[]
) => { buffer: ArrayBuffer; byteOffset: number; byteLength: number };
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
.Float16Array as Float16ArrayCtor | undefined;
export function isIntoVector(value: unknown): value is IntoVector { export function isIntoVector(value: unknown): value is IntoVector {
return ( return (
value instanceof Float32Array || value instanceof Float32Array ||
value instanceof Float64Array || value instanceof Float64Array ||
value instanceof Uint8Array ||
(float16ArrayCtor !== undefined && value instanceof float16ArrayCtor) ||
(Array.isArray(value) && !Array.isArray(value[0])) (Array.isArray(value) && !Array.isArray(value[0]))
); );
} }
/**
* Extract the underlying byte buffer and data type from a typed array
* for passing to the Rust NAPI layer without precision loss.
*/
export function extractVectorBuffer(
vector: Float32Array | Float64Array | Uint8Array,
): { data: Uint8Array; dtype: string } | null {
if (float16ArrayCtor !== undefined && vector instanceof float16ArrayCtor) {
return {
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
dtype: "float16",
};
}
if (vector instanceof Float64Array) {
return {
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
dtype: "float64",
};
}
if (vector instanceof Uint8Array && !(vector instanceof Float32Array)) {
return { data: vector, dtype: "uint8" };
}
return null;
}
export function isArrowTable(value: object): value is TableLike { export function isArrowTable(value: object): value is TableLike {
if (value instanceof ArrowTable) return true; if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value; return "schema" in value && "batches" in value;

View File

@@ -5,6 +5,7 @@ import {
Table as ArrowTable, Table as ArrowTable,
type IntoVector, type IntoVector,
RecordBatch, RecordBatch,
extractVectorBuffer,
fromBufferToRecordBatch, fromBufferToRecordBatch,
fromRecordBatchToBuffer, fromRecordBatchToBuffer,
tableFromIPC, tableFromIPC,
@@ -661,10 +662,8 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
const res = (async () => { const res = (async () => {
try { try {
const v = await vector; const v = await vector;
const arr = Float32Array.from(v);
//
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping // biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
const value: any = this.addQueryVector(arr); const value: any = this.addQueryVector(v);
const inner = value.inner as const inner = value.inner as
| NativeVectorQuery | NativeVectorQuery
| Promise<NativeVectorQuery>; | Promise<NativeVectorQuery>;
@@ -676,7 +675,12 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
return new VectorQuery(res); return new VectorQuery(res);
} else { } else {
super.doCall((inner) => { super.doCall((inner) => {
inner.addQueryVector(Float32Array.from(vector)); const raw = Array.isArray(vector) ? null : extractVectorBuffer(vector);
if (raw) {
inner.addQueryVectorRaw(raw.data, raw.dtype);
} else {
inner.addQueryVector(Float32Array.from(vector as number[]));
}
}); });
return this; return this;
} }
@@ -765,14 +769,23 @@ export class Query extends StandardQueryBase<NativeQuery> {
* a default `limit` of 10 will be used. @see {@link Query#limit} * a default `limit` of 10 will be used. @see {@link Query#limit}
*/ */
nearestTo(vector: IntoVector): VectorQuery { nearestTo(vector: IntoVector): VectorQuery {
const callNearestTo = (
inner: NativeQuery,
resolved: Float32Array | Float64Array | Uint8Array | number[],
): NativeVectorQuery => {
const raw = Array.isArray(resolved)
? null
: extractVectorBuffer(resolved);
if (raw) {
return inner.nearestToRaw(raw.data, raw.dtype);
}
return inner.nearestTo(Float32Array.from(resolved as number[]));
};
if (this.inner instanceof Promise) { if (this.inner instanceof Promise) {
const nativeQuery = this.inner.then(async (inner) => { const nativeQuery = this.inner.then(async (inner) => {
if (vector instanceof Promise) { const resolved = vector instanceof Promise ? await vector : vector;
const arr = await vector.then((v) => Float32Array.from(v)); return callNearestTo(inner, resolved);
return inner.nearestTo(arr);
} else {
return inner.nearestTo(Float32Array.from(vector));
}
}); });
return new VectorQuery(nativeQuery); return new VectorQuery(nativeQuery);
} }
@@ -780,10 +793,8 @@ export class Query extends StandardQueryBase<NativeQuery> {
const res = (async () => { const res = (async () => {
try { try {
const v = await vector; const v = await vector;
const arr = Float32Array.from(v);
//
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping // biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
const value: any = this.nearestTo(arr); const value: any = this.nearestTo(v);
const inner = value.inner as const inner = value.inner as
| NativeVectorQuery | NativeVectorQuery
| Promise<NativeVectorQuery>; | Promise<NativeVectorQuery>;
@@ -794,7 +805,7 @@ export class Query extends StandardQueryBase<NativeQuery> {
})(); })();
return new VectorQuery(res); return new VectorQuery(res);
} else { } else {
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector)); const vectorQuery = callNearestTo(this.inner, vector);
return new VectorQuery(vectorQuery); return new VectorQuery(vectorQuery);
} }
} }

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-arm64", "name": "@lancedb/lancedb-darwin-arm64",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node", "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-gnu", "name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node", "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-musl", "name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node", "main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-gnu", "name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node", "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-musl", "name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node", "main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-arm64-msvc", "name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"os": [ "os": [
"win32" "win32"
], ],

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-x64-msvc", "name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"os": ["win32"], "os": ["win32"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node", "main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{ {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"

View File

@@ -11,7 +11,7 @@
"ann" "ann"
], ],
"private": false, "private": false,
"version": "0.27.2-beta.0", "version": "0.27.2-beta.1",
"main": "dist/index.js", "main": "dist/index.js",
"exports": { "exports": {
".": "./dist/index.js", ".": "./dist/index.js",

View File

@@ -3,6 +3,12 @@
use std::sync::Arc; use std::sync::Arc;
use arrow_array::{
Array, Float16Array as ArrowFloat16Array, Float32Array as ArrowFloat32Array,
Float64Array as ArrowFloat64Array, UInt8Array as ArrowUInt8Array,
};
use arrow_buffer::ScalarBuffer;
use half::f16;
use lancedb::index::scalar::{ use lancedb::index::scalar::{
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur, BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
Operator, PhraseQuery, Operator, PhraseQuery,
@@ -24,6 +30,33 @@ use crate::rerankers::RerankHybridCallbackArgs;
use crate::rerankers::Reranker; use crate::rerankers::Reranker;
use crate::util::{parse_distance_type, schema_to_buffer}; use crate::util::{parse_distance_type, schema_to_buffer};
fn bytes_to_arrow_array(data: Uint8Array, dtype: String) -> napi::Result<Arc<dyn Array>> {
let buf = arrow_buffer::Buffer::from(data.to_vec());
let num_bytes = buf.len();
match dtype.as_str() {
"float16" => {
let scalar_buf = ScalarBuffer::<f16>::new(buf, 0, num_bytes / 2);
Ok(Arc::new(ArrowFloat16Array::new(scalar_buf, None)))
}
"float32" => {
let scalar_buf = ScalarBuffer::<f32>::new(buf, 0, num_bytes / 4);
Ok(Arc::new(ArrowFloat32Array::new(scalar_buf, None)))
}
"float64" => {
let scalar_buf = ScalarBuffer::<f64>::new(buf, 0, num_bytes / 8);
Ok(Arc::new(ArrowFloat64Array::new(scalar_buf, None)))
}
"uint8" => {
let scalar_buf = ScalarBuffer::<u8>::new(buf, 0, num_bytes);
Ok(Arc::new(ArrowUInt8Array::new(scalar_buf, None)))
}
_ => Err(napi::Error::from_reason(format!(
"Unsupported vector dtype: {}. Expected one of: float16, float32, float64, uint8",
dtype
))),
}
}
#[napi] #[napi]
pub struct Query { pub struct Query {
inner: LanceDbQuery, inner: LanceDbQuery,
@@ -78,6 +111,13 @@ impl Query {
Ok(VectorQuery { inner }) Ok(VectorQuery { inner })
} }
#[napi]
pub fn nearest_to_raw(&mut self, data: Uint8Array, dtype: String) -> Result<VectorQuery> {
let array = bytes_to_arrow_array(data, dtype)?;
let inner = self.inner.clone().nearest_to(array).default_error()?;
Ok(VectorQuery { inner })
}
#[napi] #[napi]
pub fn fast_search(&mut self) { pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search(); self.inner = self.inner.clone().fast_search();
@@ -163,6 +203,13 @@ impl VectorQuery {
Ok(()) Ok(())
} }
#[napi]
pub fn add_query_vector_raw(&mut self, data: Uint8Array, dtype: String) -> Result<()> {
let array = bytes_to_arrow_array(data, dtype)?;
self.inner = self.inner.clone().add_query_vector(array).default_error()?;
Ok(())
}
#[napi] #[napi]
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> { pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
let distance_type = parse_distance_type(distance_type)?; let distance_type = parse_distance_type(distance_type)?;

2
python/.gitignore vendored
View File

@@ -1,3 +1,5 @@
# Test data created by some example tests # Test data created by some example tests
data/ data/
_lancedb.pyd _lancedb.pyd
# macOS debug symbols bundle generated during build
*.dSYM/

View File

@@ -18,6 +18,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
from .io import StorageOptionsProvider from .io import StorageOptionsProvider
from .remote import ClientConfig from .remote import ClientConfig
from .remote.db import RemoteDBConnection from .remote.db import RemoteDBConnection
from .expr import Expr, col, lit, func
from .schema import vector from .schema import vector
from .table import AsyncTable, Table from .table import AsyncTable, Table
from ._lancedb import Session from ._lancedb import Session
@@ -271,6 +272,10 @@ __all__ = [
"AsyncConnection", "AsyncConnection",
"AsyncLanceNamespaceDBConnection", "AsyncLanceNamespaceDBConnection",
"AsyncTable", "AsyncTable",
"col",
"Expr",
"func",
"lit",
"URI", "URI",
"sanitize_uri", "sanitize_uri",
"vector", "vector",

View File

@@ -27,6 +27,32 @@ from .remote import ClientConfig
IvfHnswPq: type[HnswPq] = HnswPq IvfHnswPq: type[HnswPq] = HnswPq
IvfHnswSq: type[HnswSq] = HnswSq IvfHnswSq: type[HnswSq] = HnswSq
class PyExpr:
"""A type-safe DataFusion expression node (Rust-side handle)."""
def eq(self, other: "PyExpr") -> "PyExpr": ...
def ne(self, other: "PyExpr") -> "PyExpr": ...
def lt(self, other: "PyExpr") -> "PyExpr": ...
def lte(self, other: "PyExpr") -> "PyExpr": ...
def gt(self, other: "PyExpr") -> "PyExpr": ...
def gte(self, other: "PyExpr") -> "PyExpr": ...
def and_(self, other: "PyExpr") -> "PyExpr": ...
def or_(self, other: "PyExpr") -> "PyExpr": ...
def not_(self) -> "PyExpr": ...
def add(self, other: "PyExpr") -> "PyExpr": ...
def sub(self, other: "PyExpr") -> "PyExpr": ...
def mul(self, other: "PyExpr") -> "PyExpr": ...
def div(self, other: "PyExpr") -> "PyExpr": ...
def lower(self) -> "PyExpr": ...
def upper(self) -> "PyExpr": ...
def contains(self, substr: "PyExpr") -> "PyExpr": ...
def cast(self, data_type: pa.DataType) -> "PyExpr": ...
def to_sql(self) -> str: ...
def expr_col(name: str) -> PyExpr: ...
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...
class Session: class Session:
def __init__( def __init__(
self, self,
@@ -225,7 +251,9 @@ class RecordBatchStream:
class Query: class Query:
def where(self, filter: str): ... def where(self, filter: str): ...
def select(self, columns: Tuple[str, str]): ... def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def select_columns(self, columns: List[str]): ... def select_columns(self, columns: List[str]): ...
def limit(self, limit: int): ... def limit(self, limit: int): ...
def offset(self, offset: int): ... def offset(self, offset: int): ...
@@ -251,7 +279,9 @@ class TakeQuery:
class FTSQuery: class FTSQuery:
def where(self, filter: str): ... def where(self, filter: str): ...
def select(self, columns: List[str]): ... def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def limit(self, limit: int): ... def limit(self, limit: int): ...
def offset(self, offset: int): ... def offset(self, offset: int): ...
def fast_search(self): ... def fast_search(self): ...
@@ -270,7 +300,9 @@ class VectorQuery:
async def output_schema(self) -> pa.Schema: ... async def output_schema(self) -> pa.Schema: ...
async def execute(self) -> RecordBatchStream: ... async def execute(self) -> RecordBatchStream: ...
def where(self, filter: str): ... def where(self, filter: str): ...
def select(self, columns: List[str]): ... def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def select_with_projection(self, columns: Tuple[str, str]): ... def select_with_projection(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ... def limit(self, limit: int): ...
def offset(self, offset: int): ... def offset(self, offset: int): ...
@@ -287,7 +319,9 @@ class VectorQuery:
class HybridQuery: class HybridQuery:
def where(self, filter: str): ... def where(self, filter: str): ...
def select(self, columns: List[str]): ... def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def limit(self, limit: int): ... def limit(self, limit: int): ...
def offset(self, offset: int): ... def offset(self, offset: int): ...
def fast_search(self): ... def fast_search(self): ...

View File

@@ -10,6 +10,7 @@ import sys
import threading import threading
import time import time
import urllib.error import urllib.error
import urllib.request
import weakref import weakref
import logging import logging
from functools import wraps from functools import wraps

View File

@@ -0,0 +1,298 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Type-safe expression builder for filters and projections.
Instead of writing raw SQL strings you can build expressions with Python
operators::
from lancedb.expr import col, lit
# filter: age > 18 AND status = 'active'
filt = (col("age") > lit(18)) & (col("status") == lit("active"))
# projection: compute a derived column
proj = {"score": col("raw_score") * lit(1.5)}
table.search().where(filt).select(proj).to_list()
"""
from __future__ import annotations
from typing import Union
import pyarrow as pa
from lancedb._lancedb import PyExpr, expr_col, expr_lit, expr_func
__all__ = ["Expr", "col", "lit", "func"]
_STR_TO_PA_TYPE: dict = {
"bool": pa.bool_(),
"boolean": pa.bool_(),
"int8": pa.int8(),
"int16": pa.int16(),
"int32": pa.int32(),
"int64": pa.int64(),
"uint8": pa.uint8(),
"uint16": pa.uint16(),
"uint32": pa.uint32(),
"uint64": pa.uint64(),
"float16": pa.float16(),
"float32": pa.float32(),
"float": pa.float32(),
"float64": pa.float64(),
"double": pa.float64(),
"string": pa.string(),
"utf8": pa.string(),
"str": pa.string(),
"large_string": pa.large_utf8(),
"large_utf8": pa.large_utf8(),
"date32": pa.date32(),
"date": pa.date32(),
"date64": pa.date64(),
}
def _coerce(value: "ExprLike") -> "Expr":
"""Return *value* as an :class:`Expr`, wrapping plain Python values via
:func:`lit` if needed."""
if isinstance(value, Expr):
return value
return lit(value)
# Type alias used in annotations.
ExprLike = Union["Expr", bool, int, float, str]
class Expr:
"""A type-safe expression node.
Construct instances with :func:`col` and :func:`lit`, then combine them
using Python operators or the named methods below.
Examples
--------
>>> from lancedb.expr import col, lit
>>> filt = (col("age") > lit(18)) & (col("name").lower() == lit("alice"))
>>> proj = {"double": col("x") * lit(2)}
"""
# Make Expr unhashable so that == returns an Expr rather than being used
# for dict keys / set membership.
__hash__ = None # type: ignore[assignment]
def __init__(self, inner: PyExpr) -> None:
self._inner = inner
# ── comparisons ──────────────────────────────────────────────────────────
def __eq__(self, other: ExprLike) -> "Expr": # type: ignore[override]
"""Equal to (``col("x") == 1``)."""
return Expr(self._inner.eq(_coerce(other)._inner))
def __ne__(self, other: ExprLike) -> "Expr": # type: ignore[override]
"""Not equal to (``col("x") != 1``)."""
return Expr(self._inner.ne(_coerce(other)._inner))
def __lt__(self, other: ExprLike) -> "Expr":
"""Less than (``col("x") < 1``)."""
return Expr(self._inner.lt(_coerce(other)._inner))
def __le__(self, other: ExprLike) -> "Expr":
"""Less than or equal to (``col("x") <= 1``)."""
return Expr(self._inner.lte(_coerce(other)._inner))
def __gt__(self, other: ExprLike) -> "Expr":
"""Greater than (``col("x") > 1``)."""
return Expr(self._inner.gt(_coerce(other)._inner))
def __ge__(self, other: ExprLike) -> "Expr":
"""Greater than or equal to (``col("x") >= 1``)."""
return Expr(self._inner.gte(_coerce(other)._inner))
# ── logical ──────────────────────────────────────────────────────────────
def __and__(self, other: "Expr") -> "Expr":
"""Logical AND (``expr_a & expr_b``)."""
return Expr(self._inner.and_(_coerce(other)._inner))
def __or__(self, other: "Expr") -> "Expr":
"""Logical OR (``expr_a | expr_b``)."""
return Expr(self._inner.or_(_coerce(other)._inner))
def __invert__(self) -> "Expr":
"""Logical NOT (``~expr``)."""
return Expr(self._inner.not_())
# ── arithmetic ───────────────────────────────────────────────────────────
def __add__(self, other: ExprLike) -> "Expr":
"""Add (``col("x") + 1``)."""
return Expr(self._inner.add(_coerce(other)._inner))
def __radd__(self, other: ExprLike) -> "Expr":
"""Right-hand add (``1 + col("x")``)."""
return Expr(_coerce(other)._inner.add(self._inner))
def __sub__(self, other: ExprLike) -> "Expr":
"""Subtract (``col("x") - 1``)."""
return Expr(self._inner.sub(_coerce(other)._inner))
def __rsub__(self, other: ExprLike) -> "Expr":
"""Right-hand subtract (``1 - col("x")``)."""
return Expr(_coerce(other)._inner.sub(self._inner))
def __mul__(self, other: ExprLike) -> "Expr":
"""Multiply (``col("x") * 2``)."""
return Expr(self._inner.mul(_coerce(other)._inner))
def __rmul__(self, other: ExprLike) -> "Expr":
"""Right-hand multiply (``2 * col("x")``)."""
return Expr(_coerce(other)._inner.mul(self._inner))
def __truediv__(self, other: ExprLike) -> "Expr":
"""Divide (``col("x") / 2``)."""
return Expr(self._inner.div(_coerce(other)._inner))
def __rtruediv__(self, other: ExprLike) -> "Expr":
"""Right-hand divide (``1 / col("x")``)."""
return Expr(_coerce(other)._inner.div(self._inner))
# ── string methods ───────────────────────────────────────────────────────
def lower(self) -> "Expr":
"""Convert string column values to lowercase."""
return Expr(self._inner.lower())
def upper(self) -> "Expr":
"""Convert string column values to uppercase."""
return Expr(self._inner.upper())
def contains(self, substr: "ExprLike") -> "Expr":
"""Return True where the string contains *substr*."""
return Expr(self._inner.contains(_coerce(substr)._inner))
# ── type cast ────────────────────────────────────────────────────────────
def cast(self, data_type: Union[str, "pa.DataType"]) -> "Expr":
"""Cast values to *data_type*.
Parameters
----------
data_type:
A PyArrow ``DataType`` (e.g. ``pa.int32()``) or one of the type
name strings: ``"bool"``, ``"int8"``, ``"int16"``, ``"int32"``,
``"int64"``, ``"uint8"````"uint64"``, ``"float32"``,
``"float64"``, ``"string"``, ``"date32"``, ``"date64"``.
"""
if isinstance(data_type, str):
try:
data_type = _STR_TO_PA_TYPE[data_type]
except KeyError:
raise ValueError(
f"unsupported data type: '{data_type}'. Supported: "
f"{', '.join(_STR_TO_PA_TYPE)}"
)
return Expr(self._inner.cast(data_type))
# ── named comparison helpers (alternative to operators) ──────────────────
def eq(self, other: ExprLike) -> "Expr":
"""Equal to."""
return self.__eq__(other)
def ne(self, other: ExprLike) -> "Expr":
"""Not equal to."""
return self.__ne__(other)
def lt(self, other: ExprLike) -> "Expr":
"""Less than."""
return self.__lt__(other)
def lte(self, other: ExprLike) -> "Expr":
"""Less than or equal to."""
return self.__le__(other)
def gt(self, other: ExprLike) -> "Expr":
"""Greater than."""
return self.__gt__(other)
def gte(self, other: ExprLike) -> "Expr":
"""Greater than or equal to."""
return self.__ge__(other)
def and_(self, other: "Expr") -> "Expr":
"""Logical AND."""
return self.__and__(other)
def or_(self, other: "Expr") -> "Expr":
"""Logical OR."""
return self.__or__(other)
# ── utilities ────────────────────────────────────────────────────────────
def to_sql(self) -> str:
"""Render the expression as a SQL string (useful for debugging)."""
return self._inner.to_sql()
def __repr__(self) -> str:
return f"Expr({self._inner.to_sql()})"
# ── free functions ────────────────────────────────────────────────────────────
def col(name: str) -> Expr:
"""Reference a table column by name.
Parameters
----------
name:
The column name.
Examples
--------
>>> from lancedb.expr import col, lit
>>> col("age") > lit(18)
Expr((age > 18))
"""
return Expr(expr_col(name))
def lit(value: Union[bool, int, float, str]) -> Expr:
"""Create a literal (constant) value expression.
Parameters
----------
value:
A Python ``bool``, ``int``, ``float``, or ``str``.
Examples
--------
>>> from lancedb.expr import col, lit
>>> col("price") * lit(1.1)
Expr((price * 1.1))
"""
return Expr(expr_lit(value))
def func(name: str, *args: ExprLike) -> Expr:
"""Call an arbitrary SQL function by name.
Parameters
----------
name:
The SQL function name (e.g. ``"lower"``, ``"upper"``).
*args:
The function arguments as :class:`Expr` or plain Python literals.
Examples
--------
>>> from lancedb.expr import col, func
>>> func("lower", col("name"))
Expr(lower(name))
"""
inner_args = [_coerce(a)._inner for a in args]
return Expr(expr_func(name, inner_args))

View File

@@ -38,6 +38,7 @@ from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result from .rerankers.util import check_reranker_result
from .util import flatten_columns from .util import flatten_columns
from .expr import Expr
from lancedb._lancedb import fts_query_to_json from lancedb._lancedb import fts_query_to_json
from typing_extensions import Annotated from typing_extensions import Annotated
@@ -70,7 +71,7 @@ def ensure_vector_query(
) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]: ) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
if isinstance(val, list): if isinstance(val, list):
if len(val) == 0: if len(val) == 0:
return ValueError("Vector query must be a non-empty list") raise ValueError("Vector query must be a non-empty list")
sample = val[0] sample = val[0]
else: else:
if isinstance(val, float): if isinstance(val, float):
@@ -83,7 +84,7 @@ def ensure_vector_query(
return val return val
if isinstance(sample, list): if isinstance(sample, list):
if len(sample) == 0: if len(sample) == 0:
return ValueError("Vector query must be a non-empty list") raise ValueError("Vector query must be a non-empty list")
if isinstance(sample[0], float): if isinstance(sample[0], float):
# val is list of list of floats # val is list of list of floats
return val return val
@@ -449,8 +450,8 @@ class Query(pydantic.BaseModel):
ensure_vector_query, ensure_vector_query,
] = None ] = None
# sql filter to refine the query with # sql filter or type-safe Expr to refine the query with
filter: Optional[str] = None filter: Optional[Union[str, Expr]] = None
# if True then apply the filter after vector search # if True then apply the filter after vector search
postfilter: Optional[bool] = None postfilter: Optional[bool] = None
@@ -464,8 +465,8 @@ class Query(pydantic.BaseModel):
# distance type to use for vector search # distance type to use for vector search
distance_type: Optional[str] = None distance_type: Optional[str] = None
# which columns to return in the results # which columns to return in the results (dict values may be str or Expr)
columns: Optional[Union[List[str], Dict[str, str]]] = None columns: Optional[Union[List[str], Dict[str, Union[str, Expr]]]] = None
# minimum number of IVF partitions to search # minimum number of IVF partitions to search
# #
@@ -856,14 +857,15 @@ class LanceQueryBuilder(ABC):
self._offset = offset self._offset = offset
return self return self
def select(self, columns: Union[list[str], dict[str, str]]) -> Self: def select(self, columns: Union[list[str], dict[str, Union[str, Expr]]]) -> Self:
"""Set the columns to return. """Set the columns to return.
Parameters Parameters
---------- ----------
columns: list of str, or dict of str to str default None columns: list of str, or dict of str to str or Expr
List of column names to be fetched. List of column names to be fetched.
Or a dictionary of column names to SQL expressions. Or a dictionary of column names to SQL expressions or
:class:`~lancedb.expr.Expr` objects.
All columns are fetched if None or unspecified. All columns are fetched if None or unspecified.
Returns Returns
@@ -877,15 +879,15 @@ class LanceQueryBuilder(ABC):
raise ValueError("columns must be a list or a dictionary") raise ValueError("columns must be a list or a dictionary")
return self return self
def where(self, where: str, prefilter: bool = True) -> Self: def where(self, where: Union[str, Expr], prefilter: bool = True) -> Self:
"""Set the where clause. """Set the where clause.
Parameters Parameters
---------- ----------
where: str where: str or :class:`~lancedb.expr.Expr`
The where clause which is a valid SQL where clause. See The filter condition. Can be a SQL string or a type-safe
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_ :class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
for valid SQL expressions. and :func:`~lancedb.expr.lit`.
prefilter: bool, default True prefilter: bool, default True
If True, apply the filter before vector search, otherwise the If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search. filter is applied on the result of vector search.
@@ -1355,15 +1357,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
return result_set return result_set
def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder: def where(
self, where: Union[str, Expr], prefilter: bool = None
) -> LanceVectorQueryBuilder:
"""Set the where clause. """Set the where clause.
Parameters Parameters
---------- ----------
where: str where: str or :class:`~lancedb.expr.Expr`
The where clause which is a valid SQL where clause. See The filter condition. Can be a SQL string or a type-safe
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_ :class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
for valid SQL expressions. and :func:`~lancedb.expr.lit`.
prefilter: bool, default True prefilter: bool, default True
If True, apply the filter before vector search, otherwise the If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search. filter is applied on the result of vector search.
@@ -2286,10 +2290,20 @@ class AsyncQueryBase(object):
""" """
if isinstance(columns, list) and all(isinstance(c, str) for c in columns): if isinstance(columns, list) and all(isinstance(c, str) for c in columns):
self._inner.select_columns(columns) self._inner.select_columns(columns)
elif isinstance(columns, dict) and all( elif isinstance(columns, dict) and all(isinstance(k, str) for k in columns):
isinstance(k, str) and isinstance(v, str) for k, v in columns.items() if any(isinstance(v, Expr) for v in columns.values()):
): # At least one value is an Expr — use the type-safe path.
self._inner.select(list(columns.items())) from .expr import _coerce
pairs = [(k, _coerce(v)._inner) for k, v in columns.items()]
self._inner.select_expr(pairs)
elif all(isinstance(v, str) for v in columns.values()):
self._inner.select(list(columns.items()))
else:
raise TypeError(
"dict values must be str or Expr, got "
+ str({k: type(v) for k, v in columns.items()})
)
else: else:
raise TypeError("columns must be a list of column names or a dict") raise TypeError("columns must be a list of column names or a dict")
return self return self
@@ -2529,11 +2543,13 @@ class AsyncStandardQuery(AsyncQueryBase):
""" """
super().__init__(inner) super().__init__(inner)
def where(self, predicate: str) -> Self: def where(self, predicate: Union[str, Expr]) -> Self:
""" """
Only return rows matching the given predicate Only return rows matching the given predicate
The predicate should be supplied as an SQL query string. The predicate can be a SQL string or a type-safe
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
and :func:`~lancedb.expr.lit`.
Examples Examples
-------- --------
@@ -2545,7 +2561,10 @@ class AsyncStandardQuery(AsyncQueryBase):
Filtering performance can often be improved by creating a scalar index Filtering performance can often be improved by creating a scalar index
on the filter column(s). on the filter column(s).
""" """
self._inner.where(predicate) if isinstance(predicate, Expr):
self._inner.where_expr(predicate._inner)
else:
self._inner.where(predicate)
return self return self
def limit(self, limit: int) -> Self: def limit(self, limit: int) -> Self:

View File

@@ -568,4 +568,4 @@ class RemoteDBConnection(DBConnection):
async def close(self): async def close(self):
"""Close the connection to the database.""" """Close the connection to the database."""
self._client.close() self._conn.close()

View File

@@ -278,7 +278,7 @@ def _sanitize_data(
if metadata: if metadata:
new_metadata = target_schema.metadata or {} new_metadata = target_schema.metadata or {}
new_metadata = new_metadata.update(metadata) new_metadata.update(metadata)
target_schema = target_schema.with_metadata(new_metadata) target_schema = target_schema.with_metadata(new_metadata)
_validate_schema(target_schema) _validate_schema(target_schema)
@@ -3857,7 +3857,13 @@ class AsyncTable:
# _santitize_data is an old code path, but we will use it until the # _santitize_data is an old code path, but we will use it until the
# new code path is ready. # new code path is ready.
if on_bad_vectors != "error" or ( if mode == "overwrite":
# For overwrite, apply the same preprocessing as create_table
# so vector columns are inferred as FixedSizeList.
data, _ = sanitize_create_table(
data, None, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
elif on_bad_vectors != "error" or (
schema.metadata is not None and b"embedding_functions" in schema.metadata schema.metadata is not None and b"embedding_functions" in schema.metadata
): ):
data = _sanitize_data( data = _sanitize_data(
@@ -4205,7 +4211,7 @@ class AsyncTable:
async_query = async_query.offset(query.offset) async_query = async_query.offset(query.offset)
if query.columns: if query.columns:
async_query = async_query.select(query.columns) async_query = async_query.select(query.columns)
if query.filter: if query.filter is not None:
async_query = async_query.where(query.filter) async_query = async_query.where(query.filter)
if query.fast_search: if query.fast_search:
async_query = async_query.fast_search() async_query = async_query.fast_search()

View File

@@ -546,3 +546,24 @@ def test_openai_no_retry_on_401(mock_sleep):
assert mock_func.call_count == 1 assert mock_func.call_count == 1
# Verify that sleep was never called (no retries) # Verify that sleep was never called (no retries)
assert mock_sleep.call_count == 0 assert mock_sleep.call_count == 0
def test_url_retrieve_downloads_image():
"""
Embedding functions like open-clip, siglip, and jinaai use url_retrieve()
to download images from HTTP URLs. For example, open_clip._to_pil() calls:
PIL_Image.open(io.BytesIO(url_retrieve(image)))
Verify that url_retrieve() can download an image and open it as PIL Image,
matching the real usage pattern in embedding functions.
"""
import io
Image = pytest.importorskip("PIL.Image")
from lancedb.embeddings.utils import url_retrieve
image_url = "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"
image_bytes = url_retrieve(image_url)
img = Image.open(io.BytesIO(image_bytes))
assert img.size[0] > 0 and img.size[1] > 0

View File

@@ -8,6 +8,7 @@ import shutil
import pytest import pytest
import pyarrow as pa import pyarrow as pa
import lancedb import lancedb
from lance_namespace.errors import NamespaceNotEmptyError, TableNotFoundError
class TestNamespaceConnection: class TestNamespaceConnection:
@@ -130,7 +131,7 @@ class TestNamespaceConnection:
assert len(list(db.table_names(namespace=["test_ns"]))) == 0 assert len(list(db.table_names(namespace=["test_ns"]))) == 0
# Should not be able to open dropped table # Should not be able to open dropped table
with pytest.raises(RuntimeError): with pytest.raises(TableNotFoundError):
db.open_table("table1", namespace=["test_ns"]) db.open_table("table1", namespace=["test_ns"])
def test_create_table_with_schema(self): def test_create_table_with_schema(self):
@@ -340,7 +341,7 @@ class TestNamespaceConnection:
db.create_table("test_table", schema=schema, namespace=["test_namespace"]) db.create_table("test_table", schema=schema, namespace=["test_namespace"])
# Try to drop namespace with tables - should fail # Try to drop namespace with tables - should fail
with pytest.raises(RuntimeError, match="is not empty"): with pytest.raises(NamespaceNotEmptyError):
db.drop_namespace(["test_namespace"]) db.drop_namespace(["test_namespace"])
# Drop table first # Drop table first

View File

@@ -30,6 +30,7 @@ from lancedb.query import (
PhraseQuery, PhraseQuery,
Query, Query,
FullTextSearchQuery, FullTextSearchQuery,
ensure_vector_query,
) )
from lancedb.rerankers.cross_encoder import CrossEncoderReranker from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable from lancedb.table import AsyncTable, LanceTable
@@ -1501,6 +1502,18 @@ def test_search_empty_table(mem_db):
assert results == [] assert results == []
def test_ensure_vector_query_empty_list():
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
with pytest.raises(ValueError, match="non-empty"):
ensure_vector_query([])
def test_ensure_vector_query_nested_empty_list():
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
with pytest.raises(ValueError, match="non-empty"):
ensure_vector_query([[]])
def test_fast_search(tmp_path): def test_fast_search(tmp_path):
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)

View File

@@ -1201,6 +1201,18 @@ async def test_header_provider_overrides_static_headers():
await db.table_names() await db.table_names()
def test_close():
"""Test that close() works without AttributeError."""
import asyncio
def handler(req):
req.send_response(200)
req.end_headers()
with mock_lancedb_connection(handler) as db:
asyncio.run(db.close())
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit]) @pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
def test_background_loop_cancellation(exception): def test_background_loop_cancellation(exception):
"""Test that BackgroundEventLoop.run() cancels the future on interrupt.""" """Test that BackgroundEventLoop.run() cancels the future on interrupt."""

View File

@@ -527,6 +527,36 @@ async def test_add_async(mem_db_async: AsyncConnection):
assert await table.count_rows() == 3 assert await table.count_rows() == 3
def test_add_overwrite_infers_vector_schema(mem_db: DBConnection):
"""Overwrite should infer vector columns the same way create_table does.
Regression test for https://github.com/lancedb/lancedb/issues/3183
"""
table = mem_db.create_table(
"test_overwrite_vec",
data=[
{"vector": [1.0, 2.0, 3.0, 4.0], "item": "foo"},
{"vector": [5.0, 6.0, 7.0, 8.0], "item": "bar"},
],
)
# create_table infers vector as fixed_size_list<float32, 4>
original_type = table.schema.field("vector").type
assert pa.types.is_fixed_size_list(original_type)
# overwrite with plain Python lists (PyArrow infers list<double>)
table.add(
[
{"vector": [10.0, 20.0, 30.0, 40.0], "item": "baz"},
],
mode="overwrite",
)
# overwrite should infer vector column the same way as create_table
new_type = table.schema.field("vector").type
assert pa.types.is_fixed_size_list(new_type), (
f"Expected fixed_size_list after overwrite, got {new_type}"
)
def test_add_progress_callback(mem_db: DBConnection): def test_add_progress_callback(mem_db: DBConnection):
table = mem_db.create_table( table = mem_db.create_table(
"test", "test",
@@ -2143,3 +2173,33 @@ def test_table_uri(tmp_path):
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
table = db.create_table("my_table", data=[{"x": 0}]) table = db.create_table("my_table", data=[{"x": 0}])
assert table.uri == str(tmp_path / "my_table.lance") assert table.uri == str(tmp_path / "my_table.lance")
def test_sanitize_data_metadata_not_stripped():
"""Regression test: dict.update() returns None, so assigning its result
would silently replace metadata with None, causing with_metadata(None)
to strip all schema metadata from the target schema."""
from lancedb.table import _sanitize_data
schema = pa.schema(
[pa.field("x", pa.int64())],
metadata={b"existing_key": b"existing_value"},
)
batch = pa.record_batch([pa.array([1, 2, 3])], schema=schema)
# Use a different field type so the reader and target schemas differ,
# forcing _cast_to_target_schema to rebuild the schema with the
# target's metadata (instead of taking the fast-path).
target_schema = pa.schema(
[pa.field("x", pa.int32())],
metadata={b"existing_key": b"existing_value"},
)
reader = pa.RecordBatchReader.from_batches(schema, [batch])
metadata = {b"new_key": b"new_value"}
result = _sanitize_data(reader, target_schema=target_schema, metadata=metadata)
result_schema = result.schema
assert result_schema.metadata is not None
assert result_schema.metadata[b"existing_key"] == b"existing_value"
assert result_schema.metadata[b"new_key"] == b"new_value"

175
python/src/expr.rs Normal file
View File

@@ -0,0 +1,175 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! PyO3 bindings for the LanceDB expression builder API.
//!
//! This module exposes [`PyExpr`] and helper free functions so Python can
//! build type-safe filter / projection expressions that map directly to
//! DataFusion [`Expr`] nodes, bypassing SQL string parsing.
use arrow::{datatypes::DataType, pyarrow::PyArrowType};
use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper};
use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction};
/// A type-safe DataFusion expression.
///
/// Instances are constructed via the free functions [`expr_col`] and
/// [`expr_lit`] and combined with the methods on this struct. On the Python
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
/// and adds Python operator overloads.
#[pyclass(name = "PyExpr")]
#[derive(Clone)]
pub struct PyExpr(pub DfExpr);
#[pymethods]
impl PyExpr {
// ── comparisons ──────────────────────────────────────────────────────────
fn eq(&self, other: &Self) -> Self {
Self(self.0.clone().eq(other.0.clone()))
}
fn ne(&self, other: &Self) -> Self {
Self(self.0.clone().not_eq(other.0.clone()))
}
fn lt(&self, other: &Self) -> Self {
Self(self.0.clone().lt(other.0.clone()))
}
fn lte(&self, other: &Self) -> Self {
Self(self.0.clone().lt_eq(other.0.clone()))
}
fn gt(&self, other: &Self) -> Self {
Self(self.0.clone().gt(other.0.clone()))
}
fn gte(&self, other: &Self) -> Self {
Self(self.0.clone().gt_eq(other.0.clone()))
}
// ── logical ──────────────────────────────────────────────────────────────
fn and_(&self, other: &Self) -> Self {
Self(self.0.clone().and(other.0.clone()))
}
fn or_(&self, other: &Self) -> Self {
Self(self.0.clone().or(other.0.clone()))
}
fn not_(&self) -> Self {
use std::ops::Not;
Self(self.0.clone().not())
}
// ── arithmetic ───────────────────────────────────────────────────────────
fn add(&self, other: &Self) -> Self {
use std::ops::Add;
Self(self.0.clone().add(other.0.clone()))
}
fn sub(&self, other: &Self) -> Self {
use std::ops::Sub;
Self(self.0.clone().sub(other.0.clone()))
}
fn mul(&self, other: &Self) -> Self {
use std::ops::Mul;
Self(self.0.clone().mul(other.0.clone()))
}
fn div(&self, other: &Self) -> Self {
use std::ops::Div;
Self(self.0.clone().div(other.0.clone()))
}
// ── string functions ─────────────────────────────────────────────────────
/// Convert string column to lowercase.
fn lower(&self) -> Self {
Self(lower(self.0.clone()))
}
/// Convert string column to uppercase.
fn upper(&self) -> Self {
Self(upper(self.0.clone()))
}
/// Test whether the string contains `substr`.
fn contains(&self, substr: &Self) -> Self {
Self(contains(self.0.clone(), substr.0.clone()))
}
// ── type cast ────────────────────────────────────────────────────────────
/// Cast the expression to `data_type`.
///
/// `data_type` must be a PyArrow `DataType` (e.g. `pa.int32()`).
/// On the Python side, `lancedb.expr.Expr.cast` also accepts type name
/// strings via `pa.lib.ensure_type` before forwarding here.
fn cast(&self, data_type: PyArrowType<DataType>) -> Self {
Self(expr_cast(self.0.clone(), data_type.0))
}
// ── utilities ────────────────────────────────────────────────────────────
/// Render the expression as a SQL string (useful for debugging).
fn to_sql(&self) -> PyResult<String> {
lancedb::expr::expr_to_sql_string(&self.0).map_err(|e| PyValueError::new_err(e.to_string()))
}
fn __repr__(&self) -> PyResult<String> {
let sql =
lancedb::expr::expr_to_sql_string(&self.0).unwrap_or_else(|_| "<expr>".to_string());
Ok(format!("PyExpr({})", sql))
}
}
// ── free functions ────────────────────────────────────────────────────────────
/// Create a column reference expression.
///
/// The column name is preserved exactly as given (case-sensitive), so
/// `col("firstName")` correctly references a field named `firstName`.
#[pyfunction]
pub fn expr_col(name: &str) -> PyExpr {
PyExpr(ldb_col(name))
}
/// Create a literal value expression.
///
/// Supported Python types: `bool`, `int`, `float`, `str`.
#[pyfunction]
pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
// bool must be checked before int because bool is a subclass of int in Python
if let Ok(b) = value.extract::<bool>() {
return Ok(PyExpr(df_lit(b)));
}
if let Ok(i) = value.extract::<i64>() {
return Ok(PyExpr(df_lit(i)));
}
if let Ok(f) = value.extract::<f64>() {
return Ok(PyExpr(df_lit(f)));
}
if let Ok(s) = value.extract::<String>() {
return Ok(PyExpr(df_lit(s)));
}
Err(PyValueError::new_err(format!(
"unsupported literal type: {}. Supported: bool, int, float, str",
value.get_type().name()?
)))
}
/// Call an arbitrary registered SQL function by name.
///
/// See `lancedb::expr::func` for the list of supported function names.
#[pyfunction]
pub fn expr_func(name: &str, args: Vec<PyExpr>) -> PyResult<PyExpr> {
let df_args: Vec<DfExpr> = args.into_iter().map(|e| e.0).collect();
lancedb::expr::func(name, df_args)
.map(PyExpr)
.map_err(|e| PyValueError::new_err(e.to_string()))
}

View File

@@ -4,6 +4,7 @@
use arrow::RecordBatchStream; use arrow::RecordBatchStream;
use connection::{Connection, connect}; use connection::{Connection, connect};
use env_logger::Env; use env_logger::Env;
use expr::{PyExpr, expr_col, expr_func, expr_lit};
use index::IndexConfig; use index::IndexConfig;
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader}; use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
use pyo3::{ use pyo3::{
@@ -21,6 +22,7 @@ use table::{
pub mod arrow; pub mod arrow;
pub mod connection; pub mod connection;
pub mod error; pub mod error;
pub mod expr;
pub mod header; pub mod header;
pub mod index; pub mod index;
pub mod namespace; pub mod namespace;
@@ -55,10 +57,14 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<UpdateResult>()?; m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?; m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_class::<PyPermutationReader>()?; m.add_class::<PyPermutationReader>()?;
m.add_class::<PyExpr>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?; m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?; m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?; m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?; m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
m.add_function(wrap_pyfunction!(expr_col, m)?)?;
m.add_function(wrap_pyfunction!(expr_lit, m)?)?;
m.add_function(wrap_pyfunction!(expr_func, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?; m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(()) Ok(())
} }

View File

@@ -35,12 +35,10 @@ use pyo3::types::PyList;
use pyo3::types::{PyDict, PyString}; use pyo3::types::{PyDict, PyString};
use pyo3::{FromPyObject, exceptions::PyRuntimeError}; use pyo3::{FromPyObject, exceptions::PyRuntimeError};
use pyo3::{PyErr, pyclass}; use pyo3::{PyErr, pyclass};
use pyo3::{ use pyo3::{exceptions::PyValueError, intern};
exceptions::{PyNotImplementedError, PyValueError},
intern,
};
use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::future_into_py;
use crate::expr::PyExpr;
use crate::util::parse_distance_type; use crate::util::parse_distance_type;
use crate::{arrow::RecordBatchStream, util::PyLanceDB}; use crate::{arrow::RecordBatchStream, util::PyLanceDB};
use crate::{error::PythonErrorExt, index::class_name}; use crate::{error::PythonErrorExt, index::class_name};
@@ -344,9 +342,13 @@ impl<'py> IntoPyObject<'py> for PyQueryFilter {
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> { fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
match self.0 { match self.0 {
QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err( QueryFilter::Datafusion(expr) => {
"Datafusion filter has no conversion to Python", // Serialize the DataFusion expression to a SQL string so that
)), // callers (e.g. remote tables) see the same format as Sql.
let sql = lancedb::expr::expr_to_sql_string(&expr)
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
Ok(sql.into_pyobject(py)?.into_any())
}
QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()), QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()), QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
} }
@@ -370,10 +372,20 @@ impl Query {
self.inner = self.inner.clone().only_if(predicate); self.inner = self.inner.clone().only_if(predicate);
} }
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner = self.inner.clone().only_if_expr(expr.0);
}
pub fn select(&mut self, columns: Vec<(String, String)>) { pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns)); self.inner = self.inner.clone().select(Select::dynamic(&columns));
} }
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
self.inner = self.inner.clone().select(Select::Expr(pairs));
}
pub fn select_columns(&mut self, columns: Vec<String>) { pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(Select::columns(&columns)); self.inner = self.inner.clone().select(Select::columns(&columns));
} }
@@ -607,10 +619,20 @@ impl FTSQuery {
self.inner = self.inner.clone().only_if(predicate); self.inner = self.inner.clone().only_if(predicate);
} }
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner = self.inner.clone().only_if_expr(expr.0);
}
pub fn select(&mut self, columns: Vec<(String, String)>) { pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns)); self.inner = self.inner.clone().select(Select::dynamic(&columns));
} }
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
self.inner = self.inner.clone().select(Select::Expr(pairs));
}
pub fn select_columns(&mut self, columns: Vec<String>) { pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(Select::columns(&columns)); self.inner = self.inner.clone().select(Select::columns(&columns));
} }
@@ -725,6 +747,10 @@ impl VectorQuery {
self.inner = self.inner.clone().only_if(predicate); self.inner = self.inner.clone().only_if(predicate);
} }
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner = self.inner.clone().only_if_expr(expr.0);
}
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> { pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?; let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data); let array = make_array(data);
@@ -736,6 +762,12 @@ impl VectorQuery {
self.inner = self.inner.clone().select(Select::dynamic(&columns)); self.inner = self.inner.clone().select(Select::dynamic(&columns));
} }
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
self.inner = self.inner.clone().select(Select::Expr(pairs));
}
pub fn select_columns(&mut self, columns: Vec<String>) { pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(Select::columns(&columns)); self.inner = self.inner.clone().select(Select::columns(&columns));
} }
@@ -890,11 +922,21 @@ impl HybridQuery {
self.inner_fts.r#where(predicate); self.inner_fts.r#where(predicate);
} }
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner_vec.where_expr(expr.clone());
self.inner_fts.where_expr(expr);
}
pub fn select(&mut self, columns: Vec<(String, String)>) { pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner_vec.select(columns.clone()); self.inner_vec.select(columns.clone());
self.inner_fts.select(columns); self.inner_fts.select(columns);
} }
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
self.inner_vec.select_expr(columns.clone());
self.inner_fts.select_expr(columns);
}
pub fn select_columns(&mut self, columns: Vec<String>) { pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner_vec.select_columns(columns.clone()); self.inner_vec.select_columns(columns.clone());
self.inner_fts.select_columns(columns); self.inner_fts.select_columns(columns);

387
python/tests/test_expr.py Normal file
View File

@@ -0,0 +1,387 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Tests for the type-safe expression builder API."""
import pytest
import pyarrow as pa
import lancedb
from lancedb.expr import Expr, col, lit, func
# ── unit tests for Expr construction ─────────────────────────────────────────
class TestExprConstruction:
def test_col_returns_expr(self):
e = col("age")
assert isinstance(e, Expr)
def test_lit_int(self):
e = lit(42)
assert isinstance(e, Expr)
def test_lit_float(self):
e = lit(3.14)
assert isinstance(e, Expr)
def test_lit_str(self):
e = lit("hello")
assert isinstance(e, Expr)
def test_lit_bool(self):
e = lit(True)
assert isinstance(e, Expr)
def test_lit_unsupported_type_raises(self):
with pytest.raises(Exception):
lit([1, 2, 3])
def test_func(self):
e = func("lower", col("name"))
assert isinstance(e, Expr)
assert e.to_sql() == "lower(name)"
def test_func_unknown_raises(self):
with pytest.raises(Exception):
func("not_a_real_function", col("x"))
class TestExprOperators:
def test_eq_operator(self):
e = col("x") == lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x = 1)"
def test_ne_operator(self):
e = col("x") != lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x <> 1)"
def test_lt_operator(self):
e = col("age") < lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age < 18)"
def test_le_operator(self):
e = col("age") <= lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age <= 18)"
def test_gt_operator(self):
e = col("age") > lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age > 18)"
def test_ge_operator(self):
e = col("age") >= lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age >= 18)"
def test_and_operator(self):
e = (col("age") > lit(18)) & (col("status") == lit("active"))
assert isinstance(e, Expr)
assert e.to_sql() == "((age > 18) AND (status = 'active'))"
def test_or_operator(self):
e = (col("a") == lit(1)) | (col("b") == lit(2))
assert isinstance(e, Expr)
assert e.to_sql() == "((a = 1) OR (b = 2))"
def test_invert_operator(self):
e = ~(col("active") == lit(True))
assert isinstance(e, Expr)
assert e.to_sql() == "NOT (active = true)"
def test_add_operator(self):
e = col("x") + lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x + 1)"
def test_sub_operator(self):
e = col("x") - lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x - 1)"
def test_mul_operator(self):
e = col("price") * lit(1.1)
assert isinstance(e, Expr)
assert e.to_sql() == "(price * 1.1)"
def test_div_operator(self):
e = col("total") / lit(2)
assert isinstance(e, Expr)
assert e.to_sql() == "(total / 2)"
def test_radd(self):
e = lit(1) + col("x")
assert isinstance(e, Expr)
assert e.to_sql() == "(1 + x)"
def test_rmul(self):
e = lit(2) * col("x")
assert isinstance(e, Expr)
assert e.to_sql() == "(2 * x)"
def test_coerce_plain_int(self):
# Operators should auto-wrap plain Python values via lit()
e = col("age") > 18
assert isinstance(e, Expr)
assert e.to_sql() == "(age > 18)"
def test_coerce_plain_str(self):
e = col("name") == "alice"
assert isinstance(e, Expr)
assert e.to_sql() == "(name = 'alice')"
class TestExprStringMethods:
def test_lower(self):
e = col("name").lower()
assert isinstance(e, Expr)
assert e.to_sql() == "lower(name)"
def test_upper(self):
e = col("name").upper()
assert isinstance(e, Expr)
assert e.to_sql() == "upper(name)"
def test_contains(self):
e = col("text").contains(lit("hello"))
assert isinstance(e, Expr)
assert e.to_sql() == "contains(text, 'hello')"
def test_contains_with_str_coerce(self):
e = col("text").contains("hello")
assert isinstance(e, Expr)
assert e.to_sql() == "contains(text, 'hello')"
def test_chained_lower_eq(self):
e = col("name").lower() == lit("alice")
assert isinstance(e, Expr)
assert e.to_sql() == "(lower(name) = 'alice')"
class TestExprCast:
def test_cast_string(self):
e = col("id").cast("string")
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(id AS VARCHAR)"
def test_cast_int32(self):
e = col("score").cast("int32")
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(score AS INTEGER)"
def test_cast_float64(self):
e = col("val").cast("float64")
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(val AS DOUBLE)"
def test_cast_pyarrow_type(self):
e = col("score").cast(pa.int32())
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(score AS INTEGER)"
def test_cast_pyarrow_float64(self):
e = col("val").cast(pa.float64())
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(val AS DOUBLE)"
def test_cast_pyarrow_string(self):
e = col("id").cast(pa.string())
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(id AS VARCHAR)"
def test_cast_pyarrow_and_string_equivalent(self):
# pa.int32() and "int32" should produce equivalent SQL
sql_str = col("x").cast("int32").to_sql()
sql_pa = col("x").cast(pa.int32()).to_sql()
assert sql_str == sql_pa
class TestExprNamedMethods:
def test_eq_method(self):
e = col("x").eq(lit(1))
assert isinstance(e, Expr)
assert e.to_sql() == "(x = 1)"
def test_gt_method(self):
e = col("x").gt(lit(0))
assert isinstance(e, Expr)
assert e.to_sql() == "(x > 0)"
def test_and_method(self):
e = col("x").gt(lit(0)).and_(col("y").lt(lit(10)))
assert isinstance(e, Expr)
assert e.to_sql() == "((x > 0) AND (y < 10))"
def test_or_method(self):
e = col("x").eq(lit(1)).or_(col("x").eq(lit(2)))
assert isinstance(e, Expr)
assert e.to_sql() == "((x = 1) OR (x = 2))"
class TestExprRepr:
def test_repr(self):
e = col("age") > lit(18)
assert repr(e) == "Expr((age > 18))"
def test_to_sql(self):
e = col("age") > 18
assert e.to_sql() == "(age > 18)"
def test_unhashable(self):
e = col("x")
with pytest.raises(TypeError):
{e: 1}
# ── integration tests: end-to-end query against a real table ─────────────────
@pytest.fixture
def simple_table(tmp_path):
db = lancedb.connect(str(tmp_path))
data = pa.table(
{
"id": [1, 2, 3, 4, 5],
"name": ["Alice", "Bob", "Charlie", "alice", "BOB"],
"age": [25, 17, 30, 22, 15],
"score": [1.5, 2.0, 3.5, 4.0, 0.5],
}
)
return db.create_table("test", data)
class TestExprFilter:
def test_simple_gt_filter(self, simple_table):
result = simple_table.search().where(col("age") > lit(20)).to_arrow()
assert result.num_rows == 3 # ages 25, 30, 22
def test_compound_and_filter(self, simple_table):
result = (
simple_table.search()
.where((col("age") > lit(18)) & (col("score") > lit(2.0)))
.to_arrow()
)
assert result.num_rows == 2 # (30, 3.5) and (22, 4.0)
def test_string_equality_filter(self, simple_table):
result = simple_table.search().where(col("name") == lit("Bob")).to_arrow()
assert result.num_rows == 1
def test_or_filter(self, simple_table):
result = (
simple_table.search()
.where((col("age") < lit(18)) | (col("age") > lit(28)))
.to_arrow()
)
assert result.num_rows == 3 # ages 17, 30, 15
def test_coercion_no_lit(self, simple_table):
# Python values should be auto-coerced
result = simple_table.search().where(col("age") > 20).to_arrow()
assert result.num_rows == 3
def test_string_sql_still_works(self, simple_table):
# Backwards compatibility: plain strings still accepted
result = simple_table.search().where("age > 20").to_arrow()
assert result.num_rows == 3
class TestExprProjection:
def test_select_with_expr(self, simple_table):
result = (
simple_table.search()
.select({"double_score": col("score") * lit(2)})
.to_arrow()
)
assert "double_score" in result.schema.names
def test_select_mixed_str_and_expr(self, simple_table):
result = (
simple_table.search()
.select({"id": "id", "double_score": col("score") * lit(2)})
.to_arrow()
)
assert "id" in result.schema.names
assert "double_score" in result.schema.names
def test_select_list_of_columns(self, simple_table):
# Plain list of str still works
result = simple_table.search().select(["id", "name"]).to_arrow()
assert result.schema.names == ["id", "name"]
# ── column name edge cases ────────────────────────────────────────────────────
class TestColNaming:
"""Unit tests verifying that col() preserves identifiers exactly.
Identifiers that need quoting (camelCase, spaces, leading digits, unicode)
are wrapped in backticks to match the lance SQL parser's dialect.
"""
def test_camel_case_preserved_in_sql(self):
# camelCase is quoted with backticks so the case round-trips correctly.
assert col("firstName").to_sql() == "`firstName`"
def test_camel_case_in_expression(self):
assert (col("firstName") > lit(18)).to_sql() == "(`firstName` > 18)"
def test_space_in_name_quoted(self):
assert col("first name").to_sql() == "`first name`"
def test_space_in_expression(self):
assert (col("first name") == lit("A")).to_sql() == "(`first name` = 'A')"
def test_leading_digit_quoted(self):
assert col("2fast").to_sql() == "`2fast`"
def test_unicode_quoted(self):
assert col("名前").to_sql() == "`名前`"
def test_snake_case_unquoted(self):
# Plain snake_case needs no quoting.
assert col("first_name").to_sql() == "first_name"
@pytest.fixture
def special_col_table(tmp_path):
db = lancedb.connect(str(tmp_path))
data = pa.table(
{
"firstName": ["Alice", "Bob", "Charlie"],
"first name": ["A", "B", "C"],
"score": [10, 20, 30],
}
)
return db.create_table("special", data)
class TestColNamingIntegration:
def test_camel_case_filter(self, special_col_table):
result = (
special_col_table.search()
.where(col("firstName") == lit("Alice"))
.to_arrow()
)
assert result.num_rows == 1
assert result["firstName"][0].as_py() == "Alice"
def test_space_in_col_filter(self, special_col_table):
result = (
special_col_table.search().where(col("first name") == lit("B")).to_arrow()
)
assert result.num_rows == 1
def test_camel_case_projection(self, special_col_table):
result = (
special_col_table.search()
.select({"upper_name": col("firstName").upper()})
.to_arrow()
)
assert "upper_name" in result.schema.names
assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"]

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb" name = "lancedb"
version = "0.27.2-beta.0" version = "0.27.2-beta.1"
edition.workspace = true edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true

View File

@@ -240,7 +240,7 @@ impl Shuffler {
.await?; .await?;
// Need to read the entire file in a single batch for in-memory shuffling // Need to read the entire file in a single batch for in-memory shuffling
let batch = reader.read_record_batch(0, reader.num_rows()).await?; let batch = reader.read_record_batch(0, reader.num_rows()).await?;
let mut rng = rng.lock().unwrap(); let mut rng = rng.lock().unwrap_or_else(|e| e.into_inner());
Self::shuffle_batch(&batch, &mut rng, clump_size) Self::shuffle_batch(&batch, &mut rng, clump_size)
} }
}) })

View File

@@ -27,7 +27,17 @@ use arrow_schema::DataType;
use datafusion_expr::{Expr, ScalarUDF, expr_fn::cast}; use datafusion_expr::{Expr, ScalarUDF, expr_fn::cast};
use datafusion_functions::string::expr_fn as string_expr_fn; use datafusion_functions::string::expr_fn as string_expr_fn;
pub use datafusion_expr::{col, lit}; pub use datafusion_expr::lit;
/// Create a column reference expression, preserving the name exactly as given.
///
/// Unlike DataFusion's built-in [`col`][datafusion_expr::col], this function
/// does **not** normalise the identifier to lower-case, so
/// `col("firstName")` correctly references a field named `firstName`.
pub fn col(name: impl Into<String>) -> DfExpr {
use datafusion_common::Column;
DfExpr::Column(Column::new_unqualified(name))
}
pub use datafusion_expr::Expr as DfExpr; pub use datafusion_expr::Expr as DfExpr;

View File

@@ -2,11 +2,37 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors // SPDX-FileCopyrightText: Copyright The LanceDB Authors
use datafusion_expr::Expr; use datafusion_expr::Expr;
use datafusion_sql::unparser; use datafusion_sql::unparser::{self, dialect::Dialect};
/// Unparser dialect that matches the quoting style expected by the Lance SQL
/// parser. Lance uses backtick (`` ` ``) as the only delimited-identifier
/// quote character, so we must produce `` `firstName` `` rather than
/// `"firstName"` for identifiers that require quoting.
///
/// We quote an identifier when it:
/// * is a SQL reserved word, OR
/// * contains characters outside `[a-zA-Z0-9_]`, OR
/// * starts with a digit, OR
/// * contains upper-case letters (unquoted identifiers are normalised to
/// lower-case by the SQL parser, which would break case-sensitive schemas).
struct LanceSqlDialect;
impl Dialect for LanceSqlDialect {
fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
let needs_quote = identifier.chars().any(|c| c.is_ascii_uppercase())
|| !identifier
.chars()
.enumerate()
.all(|(i, c)| c == '_' || c.is_ascii_alphabetic() || (i > 0 && c.is_ascii_digit()));
if needs_quote { Some('`') } else { None }
}
}
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> { pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
let ast = unparser::expr_to_sql(expr).map_err(|e| crate::Error::InvalidInput { let ast = unparser::Unparser::new(&LanceSqlDialect)
message: format!("failed to serialize expression to SQL: {}", e), .expr_to_sql(expr)
})?; .map_err(|e| crate::Error::InvalidInput {
message: format!("failed to serialize expression to SQL: {}", e),
})?;
Ok(ast.to_string()) Ok(ast.to_string())
} }

View File

@@ -66,13 +66,13 @@ impl IoTrackingStore {
} }
fn record_read(&self, num_bytes: u64) { fn record_read(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap(); let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
stats.read_iops += 1; stats.read_iops += 1;
stats.read_bytes += num_bytes; stats.read_bytes += num_bytes;
} }
fn record_write(&self, num_bytes: u64) { fn record_write(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap(); let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
stats.write_iops += 1; stats.write_iops += 1;
stats.write_bytes += num_bytes; stats.write_bytes += num_bytes;
} }
@@ -229,10 +229,63 @@ impl MultipartUpload for IoTrackingMultipartUpload {
fn put_part(&mut self, payload: PutPayload) -> UploadPart { fn put_part(&mut self, payload: PutPayload) -> UploadPart {
{ {
let mut stats = self.stats.lock().unwrap(); let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
stats.write_iops += 1; stats.write_iops += 1;
stats.write_bytes += payload.content_length() as u64; stats.write_bytes += payload.content_length() as u64;
} }
self.target.put_part(payload) self.target.put_part(payload)
} }
} }
#[cfg(test)]
mod tests {
use super::*;
/// Helper: poison a Mutex<IoStats> by panicking while holding the lock.
fn poison_stats(stats: &Arc<Mutex<IoStats>>) {
let stats_clone = stats.clone();
let handle = std::thread::spawn(move || {
let _guard = stats_clone.lock().unwrap();
panic!("intentional panic to poison stats mutex");
});
let _ = handle.join();
assert!(stats.lock().is_err(), "mutex should be poisoned");
}
#[test]
fn test_record_read_recovers_from_poisoned_lock() {
let stats = Arc::new(Mutex::new(IoStats::default()));
let store = IoTrackingStore {
target: Arc::new(object_store::memory::InMemory::new()),
stats: stats.clone(),
};
poison_stats(&stats);
// record_read should not panic
store.record_read(1024);
// Verify the stats were updated despite poisoning
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
assert_eq!(s.read_iops, 1);
assert_eq!(s.read_bytes, 1024);
}
#[test]
fn test_record_write_recovers_from_poisoned_lock() {
let stats = Arc::new(Mutex::new(IoStats::default()));
let store = IoTrackingStore {
target: Arc::new(object_store::memory::InMemory::new()),
stats: stats.clone(),
};
poison_stats(&stats);
// record_write should not panic
store.record_write(2048);
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
assert_eq!(s.write_iops, 1);
assert_eq!(s.write_bytes, 2048);
}
}

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use std::{future::Future, time::Duration}; use std::{future::Future, time::Duration};
use arrow::compute::concat_batches; use arrow::compute::concat_batches;
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array}; use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
use arrow_schema::{DataType, SchemaRef}; use arrow_schema::{DataType, SchemaRef};
use datafusion_expr::Expr; use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::ExecutionPlan;
@@ -17,15 +17,17 @@ use lance_datafusion::exec::execute_plan;
use lance_index::scalar::FullTextSearchQuery; use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::SCORE_COL; use lance_index::scalar::inverted::SCORE_COL;
use lance_index::vector::DIST_COL; use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::DistanceType; use crate::DistanceType;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker; use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result}; use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
use crate::table::BaseTable; use crate::table::BaseTable;
use crate::utils::TimeoutStream; use crate::utils::{MaxBatchLengthStream, TimeoutStream};
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery}; use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
table::AnyQuery,
};
mod hybrid; mod hybrid;
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
} }
} }
impl QueryExecutionOptions {
fn without_output_batch_length_limit(&self) -> Self {
let mut options = self.clone();
options.max_batch_length = 0;
options
}
}
/// A trait for a query object that can be executed to get results /// A trait for a query object that can be executed to get results
/// ///
/// There are various kinds of queries but they all return results /// There are various kinds of queries but they all return results
@@ -1180,6 +1190,8 @@ impl VectorQuery {
&self, &self,
options: QueryExecutionOptions, options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> { ) -> Result<SendableRecordBatchStream> {
let max_batch_length = options.max_batch_length as usize;
let internal_options = options.without_output_batch_length_limit();
// clone query and specify we want to include row IDs, which can be needed for reranking // clone query and specify we want to include row IDs, which can be needed for reranking
let mut fts_query = Query::new(self.parent.clone()); let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone(); fts_query.request = self.request.base.clone();
@@ -1189,8 +1201,8 @@ impl VectorQuery {
vector_query.request.base.full_text_search = None; vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!( let (fts_results, vec_results) = try_join!(
fts_query.execute_with_options(options.clone()), fts_query.execute_with_options(internal_options.clone()),
vector_query.inner_execute_with_options(options) vector_query.inner_execute_with_options(internal_options)
)?; )?;
let (fts_results, vec_results) = try_join!( let (fts_results, vec_results) = try_join!(
@@ -1245,9 +1257,7 @@ impl VectorQuery {
results = results.drop_column(ROW_ID)?; results = results.drop_column(ROW_ID)?;
} }
Ok(SendableRecordBatchStream::from( Ok(single_batch_stream(results, max_batch_length))
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
} }
async fn inner_execute_with_options( async fn inner_execute_with_options(
@@ -1256,6 +1266,7 @@ impl VectorQuery {
) -> Result<SendableRecordBatchStream> { ) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan(options.clone()).await?; let plan = self.create_plan(options.clone()).await?;
let inner = execute_plan(plan, Default::default())?; let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout { let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout) TimeoutStream::new_boxed(inner, timeout)
} else { } else {
@@ -1265,6 +1276,25 @@ impl VectorQuery {
} }
} }
fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
let schema = batch.schema();
if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
return Box::pin(SimpleRecordBatchStream::new(
stream::iter([Ok(batch)]),
schema,
));
}
let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
let mut offset = 0;
while offset < batch.num_rows() {
let length = (batch.num_rows() - offset).min(max_batch_length);
batches.push(Ok(batch.slice(offset, length)));
offset += length;
}
Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
}
impl ExecutableQuery for VectorQuery { impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> { async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
let query = AnyQuery::VectorQuery(self.request.clone()); let query = AnyQuery::VectorQuery(self.request.clone());
@@ -1753,6 +1783,50 @@ mod tests {
.unwrap() .unwrap()
} }
async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
let dataset_path = tmp_dir.path().join("large_test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
4,
),
false,
),
]));
let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
4,
);
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
let conn = connect(uri).execute().await.unwrap();
conn.create_table("my_table", vec![batch])
.execute()
.await
.unwrap()
}
async fn assert_stream_batches_at_most(
mut results: SendableRecordBatchStream,
max_batch_length: usize,
) {
let mut saw_batch = false;
while let Some(batch) = results.next().await {
let batch = batch.unwrap();
saw_batch = true;
assert!(batch.num_rows() <= max_batch_length);
}
assert!(saw_batch);
}
#[tokio::test] #[tokio::test]
async fn test_execute_with_options() { async fn test_execute_with_options() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
@@ -1772,6 +1846,83 @@ mod tests {
} }
} }
#[tokio::test]
async fn test_vector_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let table = make_large_vector_table(&tmp_dir, 10_000).await;
let results = table
.query()
.nearest_to(vec![0.0, 1.0, 2.0, 3.0])
.unwrap()
.limit(10_000)
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test]
async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let rows = 512;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
dims,
);
let record_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
let table = conn
.create_table("my_table", record_batch)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let results = table
.query()
.full_text_search(FullTextSearchQuery::new("match".to_string()))
.limit(rows)
.nearest_to(&[0.0, 0.0])
.unwrap()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test] #[tokio::test]
async fn test_analyze_plan() { async fn test_analyze_plan() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();

View File

@@ -130,7 +130,10 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
// TODO: this will be used when we wire this up to Table::add(). // TODO: this will be used when we wire this up to Table::add().
#[allow(dead_code)] #[allow(dead_code)]
pub fn add_result(&self) -> Option<AddResult> { pub fn add_result(&self) -> Option<AddResult> {
self.add_result.lock().unwrap().clone() self.add_result
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone()
} }
/// Stream the input into an HTTP body as an Arrow IPC stream, capturing any /// Stream the input into an HTTP body as an Arrow IPC stream, capturing any

View File

@@ -204,7 +204,9 @@ impl ExecutionPlan for InsertExec {
let to_commit = { let to_commit = {
// Don't hold the lock over an await point. // Don't hold the lock over an await point.
let mut txns = partial_transactions.lock().unwrap(); let mut txns = partial_transactions
.lock()
.unwrap_or_else(|e| e.into_inner());
txns.push(transaction); txns.push(transaction);
if txns.len() == total_partitions { if txns.len() == total_partitions {
Some(std::mem::take(&mut *txns)) Some(std::mem::take(&mut *txns))

View File

@@ -82,7 +82,7 @@ impl DatasetConsistencyWrapper {
/// pinned dataset regardless of consistency mode. /// pinned dataset regardless of consistency mode.
pub async fn get(&self) -> Result<Arc<Dataset>> { pub async fn get(&self) -> Result<Arc<Dataset>> {
{ {
let state = self.state.lock().unwrap(); let state = self.state.lock()?;
if state.pinned_version.is_some() { if state.pinned_version.is_some() {
return Ok(state.dataset.clone()); return Ok(state.dataset.clone());
} }
@@ -101,7 +101,7 @@ impl DatasetConsistencyWrapper {
} }
ConsistencyMode::Strong => refresh_latest(self.state.clone()).await, ConsistencyMode::Strong => refresh_latest(self.state.clone()).await,
ConsistencyMode::Lazy => { ConsistencyMode::Lazy => {
let state = self.state.lock().unwrap(); let state = self.state.lock()?;
Ok(state.dataset.clone()) Ok(state.dataset.clone())
} }
} }
@@ -116,7 +116,7 @@ impl DatasetConsistencyWrapper {
/// concurrent [`as_time_travel`](Self::as_time_travel) call), the update /// concurrent [`as_time_travel`](Self::as_time_travel) call), the update
/// is silently ignored — the write already committed to storage. /// is silently ignored — the write already committed to storage.
pub fn update(&self, dataset: Dataset) { pub fn update(&self, dataset: Dataset) {
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
if state.pinned_version.is_some() { if state.pinned_version.is_some() {
// A concurrent as_time_travel() beat us here. The write succeeded // A concurrent as_time_travel() beat us here. The write succeeded
// in storage, but since we're now pinned we don't advance the // in storage, but since we're now pinned we don't advance the
@@ -139,7 +139,7 @@ impl DatasetConsistencyWrapper {
/// Check that the dataset is in a mutable mode (Latest). /// Check that the dataset is in a mutable mode (Latest).
pub fn ensure_mutable(&self) -> Result<()> { pub fn ensure_mutable(&self) -> Result<()> {
let state = self.state.lock().unwrap(); let state = self.state.lock()?;
if state.pinned_version.is_some() { if state.pinned_version.is_some() {
Err(crate::Error::InvalidInput { Err(crate::Error::InvalidInput {
message: "table cannot be modified when a specific version is checked out" message: "table cannot be modified when a specific version is checked out"
@@ -152,13 +152,16 @@ impl DatasetConsistencyWrapper {
/// Returns the version, if in time travel mode, or None otherwise. /// Returns the version, if in time travel mode, or None otherwise.
pub fn time_travel_version(&self) -> Option<u64> { pub fn time_travel_version(&self) -> Option<u64> {
self.state.lock().unwrap().pinned_version self.state
.lock()
.unwrap_or_else(|e| e.into_inner())
.pinned_version
} }
/// Convert into a wrapper in latest version mode. /// Convert into a wrapper in latest version mode.
pub async fn as_latest(&self) -> Result<()> { pub async fn as_latest(&self) -> Result<()> {
let dataset = { let dataset = {
let state = self.state.lock().unwrap(); let state = self.state.lock()?;
if state.pinned_version.is_none() { if state.pinned_version.is_none() {
return Ok(()); return Ok(());
} }
@@ -168,7 +171,7 @@ impl DatasetConsistencyWrapper {
let latest_version = dataset.latest_version_id().await?; let latest_version = dataset.latest_version_id().await?;
let new_dataset = dataset.checkout_version(latest_version).await?; let new_dataset = dataset.checkout_version(latest_version).await?;
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock()?;
if state.pinned_version.is_some() { if state.pinned_version.is_some() {
state.dataset = Arc::new(new_dataset); state.dataset = Arc::new(new_dataset);
state.pinned_version = None; state.pinned_version = None;
@@ -184,7 +187,7 @@ impl DatasetConsistencyWrapper {
let target_ref = target_version.into(); let target_ref = target_version.into();
let (should_checkout, dataset) = { let (should_checkout, dataset) = {
let state = self.state.lock().unwrap(); let state = self.state.lock()?;
let should = match state.pinned_version { let should = match state.pinned_version {
None => true, None => true,
Some(version) => match &target_ref { Some(version) => match &target_ref {
@@ -204,7 +207,7 @@ impl DatasetConsistencyWrapper {
let new_dataset = dataset.checkout_version(target_ref).await?; let new_dataset = dataset.checkout_version(target_ref).await?;
let version_value = new_dataset.version().version; let version_value = new_dataset.version().version;
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock()?;
state.dataset = Arc::new(new_dataset); state.dataset = Arc::new(new_dataset);
state.pinned_version = Some(version_value); state.pinned_version = Some(version_value);
Ok(()) Ok(())
@@ -212,7 +215,7 @@ impl DatasetConsistencyWrapper {
pub async fn reload(&self) -> Result<()> { pub async fn reload(&self) -> Result<()> {
let (dataset, pinned_version) = { let (dataset, pinned_version) = {
let state = self.state.lock().unwrap(); let state = self.state.lock()?;
(state.dataset.clone(), state.pinned_version) (state.dataset.clone(), state.pinned_version)
}; };
@@ -230,7 +233,7 @@ impl DatasetConsistencyWrapper {
let new_dataset = dataset.checkout_version(version).await?; let new_dataset = dataset.checkout_version(version).await?;
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock()?;
if state.pinned_version == Some(version) { if state.pinned_version == Some(version) {
state.dataset = Arc::new(new_dataset); state.dataset = Arc::new(new_dataset);
} }
@@ -242,14 +245,14 @@ impl DatasetConsistencyWrapper {
} }
async fn refresh_latest(state: Arc<Mutex<DatasetState>>) -> Result<Arc<Dataset>> { async fn refresh_latest(state: Arc<Mutex<DatasetState>>) -> Result<Arc<Dataset>> {
let dataset = { state.lock().unwrap().dataset.clone() }; let dataset = { state.lock()?.dataset.clone() };
let mut ds = (*dataset).clone(); let mut ds = (*dataset).clone();
ds.checkout_latest().await?; ds.checkout_latest().await?;
let new_arc = Arc::new(ds); let new_arc = Arc::new(ds);
{ {
let mut state = state.lock().unwrap(); let mut state = state.lock()?;
if state.pinned_version.is_none() if state.pinned_version.is_none()
&& new_arc.manifest().version >= state.dataset.manifest().version && new_arc.manifest().version >= state.dataset.manifest().version
{ {
@@ -612,4 +615,108 @@ mod tests {
let s = io_stats.incremental_stats(); let s = io_stats.incremental_stats();
assert_eq!(s.read_iops, 0, "step 5, elapsed={:?}", start.elapsed()); assert_eq!(s.read_iops, 0, "step 5, elapsed={:?}", start.elapsed());
} }
/// Helper: poison the mutex inside a DatasetConsistencyWrapper.
fn poison_state(wrapper: &DatasetConsistencyWrapper) {
let state = wrapper.state.clone();
let handle = std::thread::spawn(move || {
let _guard = state.lock().unwrap();
panic!("intentional panic to poison mutex");
});
let _ = handle.join(); // join collects the panic
assert!(wrapper.state.lock().is_err(), "mutex should be poisoned");
}
#[tokio::test]
async fn test_get_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
// get() should return Err, not panic
let result = wrapper.get().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_ensure_mutable_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.ensure_mutable();
assert!(result.is_err());
}
#[tokio::test]
async fn test_update_recovers_from_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let ds_v2 = append_to_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
// update() returns (), should not panic
wrapper.update(ds_v2);
}
#[tokio::test]
async fn test_time_travel_version_recovers_from_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
// Should not panic, returns whatever was in the mutex
let _version = wrapper.time_travel_version();
}
#[tokio::test]
async fn test_as_latest_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.as_latest().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_as_time_travel_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.as_time_travel(1u64).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_reload_returns_error_on_poisoned_lock() {
let dir = tempfile::tempdir().unwrap();
let uri = dir.path().to_str().unwrap();
let ds = create_test_dataset(uri).await;
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
poison_state(&wrapper);
let result = wrapper.reload().await;
assert!(result.is_err());
}
} }

View File

@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
use crate::query::{ use crate::query::{
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest, DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
}; };
use crate::utils::{TimeoutStream, default_vector_column}; use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder}; use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type}; use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::Array; use arrow_array::Array;
@@ -66,6 +66,7 @@ async fn execute_generic_query(
) -> Result<DatasetRecordBatchStream> { ) -> Result<DatasetRecordBatchStream> {
let plan = create_plan(table, query, options.clone()).await?; let plan = create_plan(table, query, options.clone()).await?;
let inner = execute_plan(plan, Default::default())?; let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout { let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout) TimeoutStream::new_boxed(inner, timeout)
} else { } else {
@@ -200,7 +201,9 @@ pub async fn create_plan(
scanner.with_row_id(); scanner.with_row_id();
} }
scanner.batch_size(options.max_batch_length as usize); if options.max_batch_length > 0 {
scanner.batch_size(options.max_batch_length as usize);
}
if query.base.fast_search { if query.base.fast_search {
scanner.fast_search(); scanner.fast_search();

View File

@@ -130,8 +130,11 @@ impl WriteProgressTracker {
pub fn record_batch(&self, rows: usize, bytes: usize) { pub fn record_batch(&self, rows: usize, bytes: usize) {
// Lock order: callback first, then rows_and_bytes. This is the only // Lock order: callback first, then rows_and_bytes. This is the only
// order used anywhere, so deadlocks cannot occur. // order used anywhere, so deadlocks cannot occur.
let mut cb = self.callback.lock().unwrap(); let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
let mut guard = self.rows_and_bytes.lock().unwrap(); let mut guard = self
.rows_and_bytes
.lock()
.unwrap_or_else(|e| e.into_inner());
guard.0 += rows; guard.0 += rows;
guard.1 += bytes; guard.1 += bytes;
let progress = self.snapshot(guard.0, guard.1, false); let progress = self.snapshot(guard.0, guard.1, false);
@@ -151,8 +154,11 @@ impl WriteProgressTracker {
/// `total_rows` is always `Some` on the final callback: it uses the known /// `total_rows` is always `Some` on the final callback: it uses the known
/// total if available, or falls back to the number of rows actually written. /// total if available, or falls back to the number of rows actually written.
pub fn finish(&self) { pub fn finish(&self) {
let mut cb = self.callback.lock().unwrap(); let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
let guard = self.rows_and_bytes.lock().unwrap(); let guard = self
.rows_and_bytes
.lock()
.unwrap_or_else(|e| e.into_inner());
let mut snap = self.snapshot(guard.0, guard.1, true); let mut snap = self.snapshot(guard.0, guard.1, true);
snap.total_rows = Some(self.total_rows.unwrap_or(guard.0)); snap.total_rows = Some(self.total_rows.unwrap_or(guard.0));
drop(guard); drop(guard);
@@ -376,4 +382,50 @@ mod tests {
} }
} }
} }
#[test]
fn test_record_batch_recovers_from_poisoned_callback_lock() {
use super::{ProgressCallback, WriteProgressTracker};
use std::sync::Mutex;
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
// Poison the callback mutex
let cb_clone = callback.clone();
let handle = std::thread::spawn(move || {
let _guard = cb_clone.lock().unwrap();
panic!("intentional panic to poison callback mutex");
});
let _ = handle.join();
assert!(
callback.lock().is_err(),
"callback mutex should be poisoned"
);
let tracker = WriteProgressTracker::new(callback, Some(100));
// record_batch should not panic
tracker.record_batch(10, 1024);
}
#[test]
fn test_finish_recovers_from_poisoned_callback_lock() {
use super::{ProgressCallback, WriteProgressTracker};
use std::sync::Mutex;
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
// Poison the callback mutex
let cb_clone = callback.clone();
let handle = std::thread::spawn(move || {
let _guard = cb_clone.lock().unwrap();
panic!("intentional panic to poison callback mutex");
});
let _ = handle.join();
let tracker = WriteProgressTracker::new(callback, Some(100));
// finish should not panic
tracker.finish();
}
} }

View File

@@ -122,7 +122,7 @@ where
/// This is a cheap synchronous check useful as a fast path before /// This is a cheap synchronous check useful as a fast path before
/// constructing a fetch closure for [`get()`](Self::get). /// constructing a fetch closure for [`get()`](Self::get).
pub fn try_get(&self) -> Option<V> { pub fn try_get(&self) -> Option<V> {
let cache = self.inner.lock().unwrap(); let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
cache.state.fresh_value(self.ttl, self.refresh_window) cache.state.fresh_value(self.ttl, self.refresh_window)
} }
@@ -138,7 +138,7 @@ where
{ {
// Fast path: check if cache is fresh // Fast path: check if cache is fresh
{ {
let cache = self.inner.lock().unwrap(); let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) { if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) {
return Ok(value); return Ok(value);
} }
@@ -147,7 +147,7 @@ where
// Slow path // Slow path
let mut fetch = Some(fetch); let mut fetch = Some(fetch);
let action = { let action = {
let mut cache = self.inner.lock().unwrap(); let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
self.determine_action(&mut cache, &mut fetch) self.determine_action(&mut cache, &mut fetch)
}; };
@@ -161,7 +161,7 @@ where
/// ///
/// This avoids a blocking fetch on the first [`get()`](Self::get) call. /// This avoids a blocking fetch on the first [`get()`](Self::get) call.
pub fn seed(&self, value: V) { pub fn seed(&self, value: V) {
let mut cache = self.inner.lock().unwrap(); let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
cache.state = State::Current(value, clock::now()); cache.state = State::Current(value, clock::now());
} }
@@ -170,7 +170,7 @@ where
/// Any in-flight background fetch from before this call will not update the /// Any in-flight background fetch from before this call will not update the
/// cache (the generation counter prevents stale writes). /// cache (the generation counter prevents stale writes).
pub fn invalidate(&self) { pub fn invalidate(&self) {
let mut cache = self.inner.lock().unwrap(); let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
cache.state = State::Empty; cache.state = State::Empty;
cache.generation += 1; cache.generation += 1;
} }
@@ -267,7 +267,7 @@ where
let fut_for_spawn = shared.clone(); let fut_for_spawn = shared.clone();
tokio::spawn(async move { tokio::spawn(async move {
let result = fut_for_spawn.await; let result = fut_for_spawn.await;
let mut cache = inner.lock().unwrap(); let mut cache = inner.lock().unwrap_or_else(|e| e.into_inner());
// Only update if no invalidation has happened since we started // Only update if no invalidation has happened since we started
if cache.generation != generation { if cache.generation != generation {
return; return;
@@ -590,4 +590,67 @@ mod tests {
let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap(); let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap();
assert_eq!(v, "fresh"); assert_eq!(v, "fresh");
} }
/// Helper: poison the inner mutex of a BackgroundCache.
fn poison_cache(cache: &BackgroundCache<String, TestError>) {
let inner = cache.inner.clone();
let handle = std::thread::spawn(move || {
let _guard = inner.lock().unwrap();
panic!("intentional panic to poison mutex");
});
let _ = handle.join();
assert!(cache.inner.lock().is_err(), "mutex should be poisoned");
}
#[tokio::test]
async fn test_try_get_recovers_from_poisoned_lock() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
// Seed a value first
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); // peek
poison_cache(&cache);
// try_get() should not panic — it recovers via unwrap_or_else
let result = cache.try_get();
// The value may or may not be fresh depending on timing, but it must not panic
let _ = result;
}
#[tokio::test]
async fn test_get_recovers_from_poisoned_lock() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
poison_cache(&cache);
// get() should not panic — it recovers and can still fetch
let result = cache.get(ok_fetcher(count.clone(), "recovered")).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "recovered");
}
#[tokio::test]
async fn test_seed_recovers_from_poisoned_lock() {
let cache = new_cache();
poison_cache(&cache);
// seed() should not panic
cache.seed("seeded".to_string());
}
#[tokio::test]
async fn test_invalidate_recovers_from_poisoned_lock() {
let cache = new_cache();
let count = Arc::new(AtomicUsize::new(0));
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
poison_cache(&cache);
// invalidate() should not panic
cache.invalidate();
}
} }

View File

@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
} }
} }
/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
pub struct MaxBatchLengthStream {
inner: SendableRecordBatchStream,
max_batch_length: Option<usize>,
buffered_batch: Option<RecordBatch>,
buffered_offset: usize,
}
impl MaxBatchLengthStream {
pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
Self {
inner,
max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
buffered_batch: None,
buffered_offset: 0,
}
}
pub fn new_boxed(
inner: SendableRecordBatchStream,
max_batch_length: usize,
) -> SendableRecordBatchStream {
if max_batch_length == 0 {
inner
} else {
Box::pin(Self::new(inner, max_batch_length))
}
}
}
impl RecordBatchStream for MaxBatchLengthStream {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}
impl Stream for MaxBatchLengthStream {
type Item = DataFusionResult<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
loop {
let Some(max_batch_length) = self.max_batch_length else {
return Pin::new(&mut self.inner).poll_next(cx);
};
if let Some(batch) = self.buffered_batch.clone() {
if self.buffered_offset < batch.num_rows() {
let remaining = batch.num_rows() - self.buffered_offset;
let length = remaining.min(max_batch_length);
let sliced = batch.slice(self.buffered_offset, length);
self.buffered_offset += length;
if self.buffered_offset >= batch.num_rows() {
self.buffered_batch = None;
self.buffered_offset = 0;
}
return std::task::Poll::Ready(Some(Ok(sliced)));
}
self.buffered_batch = None;
self.buffered_offset = 0;
}
match Pin::new(&mut self.inner).poll_next(cx) {
std::task::Poll::Ready(Some(Ok(batch))) => {
if batch.num_rows() <= max_batch_length {
return std::task::Poll::Ready(Some(Ok(batch)));
}
self.buffered_batch = Some(batch);
self.buffered_offset = 0;
}
other => return other,
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use arrow_array::Int32Array; use arrow_array::Int32Array;
@@ -470,7 +549,7 @@ mod tests {
assert_eq!(string_to_datatype(string), Some(expected)); assert_eq!(string_to_datatype(string), Some(expected));
} }
fn sample_batch() -> RecordBatch { fn sample_batch(num_rows: i32) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new( let schema = Arc::new(Schema::new(vec![Field::new(
"col1", "col1",
DataType::Int32, DataType::Int32,
@@ -478,14 +557,14 @@ mod tests {
)])); )]));
RecordBatch::try_new( RecordBatch::try_new(
schema.clone(), schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
) )
.unwrap() .unwrap()
} }
#[tokio::test] #[tokio::test]
async fn test_timeout_stream() { async fn test_timeout_stream() {
let batch = sample_batch(); let batch = sample_batch(3);
let schema = batch.schema(); let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]); let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -515,7 +594,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_timeout_stream_zero_duration() { async fn test_timeout_stream_zero_duration() {
let batch = sample_batch(); let batch = sample_batch(3);
let schema = batch.schema(); let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]); let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -534,7 +613,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_timeout_stream_completes_normally() { async fn test_timeout_stream_completes_normally() {
let batch = sample_batch(); let batch = sample_batch(3);
let schema = batch.schema(); let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]); let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -552,4 +631,35 @@ mod tests {
// Stream should be empty now // Stream should be empty now
assert!(timeout_stream.next().await.is_none()); assert!(timeout_stream.next().await.is_none());
} }
async fn collect_batch_sizes(
stream: SendableRecordBatchStream,
max_batch_length: usize,
) -> Vec<usize> {
let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
sliced_stream
.by_ref()
.map(|batch| batch.unwrap().num_rows())
.collect::<Vec<_>>()
.await
}
#[tokio::test]
async fn test_max_batch_length_stream_behaviors() {
let schema = sample_batch(7).schema();
let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
assert_eq!(
collect_batch_sizes(sendable_stream, 3).await,
vec![2, 3, 3, 1]
);
let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
schema,
stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
));
assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
}
} }