Compare commits

..

7 Commits

Author SHA1 Message Date
Lance Release
5d550124bd Bump version: 0.30.2-beta.2 → 0.30.2 2026-03-31 21:25:04 +00:00
Lance Release
c57cb310a2 Bump version: 0.30.2-beta.1 → 0.30.2-beta.2 2026-03-31 21:25:02 +00:00
Dan Tasse
97754f5123 fix: change _client reference to _conn (#3188)
This code previously referenced `self._client`, which does not exist.
This change makes it correctly call `self._conn.close()`
2026-03-31 13:29:17 -07:00
Pratik Dey
7b1c063848 feat(python): add type-safe expression builder API (#3150)
Introduces col(), lit(), func(), and Expr class as alternatives to raw
SQL strings in .where() and .select(). Expressions are backed by
DataFusion's Expr AST and serialized to SQL for remote table compat.

Resolves: 
- https://github.com/lancedb/lancedb/issues/3044 (python api's)
- https://github.com/lancedb/lancedb/issues/3043 (support for filter)
- https://github.com/lancedb/lancedb/issues/3045 (support for
projection)

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-31 11:32:49 -07:00
Will Jones
c7f189f27b chore: upgrade lance to stable 4.0.0 (#3207)
Bumps all lance-* workspace dependencies from `4.0.0-rc.3` (git source)
to the stable `4.0.0` release on crates.io, removing the `git`/`tag`
overrides.

No code changes were required — compiles and passes clippy cleanly.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-30 17:05:45 -07:00
yaommen
a0a2942ad5 fix: respect max_batch_length for Rust vector and hybrid queries (#3172)
Fixes #1540

I could not reproduce this on current `main` from Python, but I could
still reproduce it from the Rust SDK.

Python no longer reproduces because the current Python vector/hybrid
query paths re-chunk results into a `pyarrow.Table` before returning
batches. Rust still reproduced because `max_batch_length` was passed
into planning/scanning, but vector search could still emit larger
`RecordBatch`es later in execution (for example after KNN / TopK), so it
was not enforced on the final Rust output stream.

This PR enforces `max_batch_length` on the final Rust query output
stream and adds Rust regression coverage.

Before the fix, the Rust repro produced:
`num_batches=2, max_batch=8192, min_batch=1808, all_le_100=false`

After the fix, the same repro produces batches `<= 100`.

## Runnable Rust repro

Before this fix, current `main` could still return batches like `[8192,
1808]` here even with `max_batch_length = 100`:

```rust
use std::sync::Arc;

use arrow_array::{
    types::Float32Type, FixedSizeListArray, RecordBatch, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, QueryExecutionOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tmp = tempfile::tempdir()?;
    let uri = tmp.path().to_str().unwrap();

    let rows = 10_000;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            false,
        ),
    ]));

    let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
    let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
        (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
        4,
    );
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
    let reader: Box<dyn RecordBatchReader + Send> = Box::new(
        arrow_array::RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema),
    );

    let db = lancedb::connect(uri).execute().await?;
    let table = db.create_table("test", reader).execute().await?;

    let mut opts = QueryExecutionOptions::default();
    opts.max_batch_length = 100;

    let mut stream = table
        .query()
        .nearest_to(vec![0.0, 1.0, 2.0, 3.0])?
        .limit(rows)
        .execute_with_options(opts)
        .await?;

    let mut sizes = Vec::new();
    while let Some(batch) = stream.try_next().await? {
        sizes.push(batch.num_rows());
    }

    println!("{sizes:?}");
    Ok(())
}
```

Signed-off-by: yaommen <myanstu@163.com>
2026-03-30 15:43:58 -07:00
Will Jones
e3d53dd185 fix(python): skip test_url_retrieve_downloads_image when PIL not installed (#3208)
The test added in #3190 unconditionally imports `PIL`, which is an
optional dependency. This causes CI failures in environments where
Pillow isn't installed (`ModuleNotFoundError: No module named 'PIL'`).

Use `pytest.importorskip` to skip gracefully when Pillow is unavailable.

Fixes CI failure on main.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-30 14:48:49 -07:00
26 changed files with 1439 additions and 130 deletions

118
Cargo.lock generated
View File

@@ -108,7 +108,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -119,7 +119,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -2682,7 +2682,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -2876,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3072,8 +3072,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2195cc7f87e84bd695586137de99605e7e9579b26ec5e01b82960ddb4d0922f2"
dependencies = [
"arrow-array",
"rand 0.9.2",
@@ -3736,7 +3737,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2 0.5.10",
"socket2 0.6.3",
"system-configuration",
"tokio",
"tower-service",
@@ -4037,7 +4038,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde_core",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -4123,8 +4124,9 @@ dependencies = [
[[package]]
name = "lance"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efe6c3ddd79cdfd2b7e1c23cafae52806906bc40fbd97de9e8cf2f8c7a75fc04"
dependencies = [
"arrow",
"arrow-arith",
@@ -4190,8 +4192,9 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d9f5d95bdda2a2b790f1fb8028b5b6dcf661abeb3133a8bca0f3d24b054af87"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4211,8 +4214,9 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f827d6ab9f8f337a9509d5ad66a12f3314db8713868260521c344ef6135eb4e4"
dependencies = [
"arrayref",
"paste",
@@ -4221,8 +4225,9 @@ dependencies = [
[[package]]
name = "lance-core"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f1e25df6a79bf72ee6bcde0851f19b1cd36c5848c1b7db83340882d3c9fdecb"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4259,8 +4264,9 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93146de8ae720cb90edef81c2f2d0a1b065fc2f23ecff2419546f389b0fa70a4"
dependencies = [
"arrow",
"arrow-array",
@@ -4290,8 +4296,9 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccec8ce4d8e0a87a99c431dab2364398029f2ffb649c1a693c60c79e05ed30dd"
dependencies = [
"arrow",
"arrow-array",
@@ -4309,8 +4316,9 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c1aec0bbbac6bce829bc10f1ba066258126100596c375fb71908ecf11c2c2a5"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4347,8 +4355,9 @@ dependencies = [
[[package]]
name = "lance-file"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14a8c548804f5b17486dc2d3282356ed1957095a852780283bc401fdd69e9075"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4380,8 +4389,9 @@ dependencies = [
[[package]]
name = "lance-index"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da212f0090ea59f79ac3686660f596520c167fe1cb5f408900cf71d215f0e03"
dependencies = [
"arrow",
"arrow-arith",
@@ -4445,8 +4455,9 @@ dependencies = [
[[package]]
name = "lance-io"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d958eb4b56f03bbe0f5f85eb2b4e9657882812297b6f711f201ffc995f259f"
dependencies = [
"arrow",
"arrow-arith",
@@ -4487,8 +4498,9 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0285b70da35def7ed95e150fae1d5308089554e1290470403ed3c50cb235bc5e"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4504,8 +4516,9 @@ dependencies = [
[[package]]
name = "lance-namespace"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f78e2a828b654e062a495462c6e3eb4fcf0e7e907d761b8f217fc09ccd3ceac"
dependencies = [
"arrow",
"async-trait",
@@ -4518,8 +4531,9 @@ dependencies = [
[[package]]
name = "lance-namespace-impls"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2392314f3da38f00d166295e44244208a65ccfc256e274fa8631849fc3f4d94"
dependencies = [
"arrow",
"arrow-ipc",
@@ -4533,7 +4547,6 @@ dependencies = [
"lance-core",
"lance-index",
"lance-io",
"lance-linalg",
"lance-namespace",
"lance-table",
"log",
@@ -4564,8 +4577,9 @@ dependencies = [
[[package]]
name = "lance-table"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3df9c4adca3eb2074b3850432a9fb34248a3d90c3d6427d158b13ff9355664ee"
dependencies = [
"arrow",
"arrow-array",
@@ -4604,8 +4618,9 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ed7119bdd6983718387b4ac44af873a165262ca94f181b104cd6f97912eb3bf"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -5308,7 +5323,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -6287,7 +6302,7 @@ version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"itertools 0.14.0",
"log",
"multimap",
@@ -6474,7 +6489,7 @@ dependencies = [
"quinn-udp",
"rustc-hash",
"rustls 0.23.37",
"socket2 0.5.10",
"socket2 0.6.3",
"thiserror 2.0.18",
"tokio",
"tracing",
@@ -6511,9 +6526,9 @@ dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.5.10",
"socket2 0.6.3",
"tracing",
"windows-sys 0.52.0",
"windows-sys 0.60.2",
]
[[package]]
@@ -7042,7 +7057,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -7055,7 +7070,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.12.1",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -7575,7 +7590,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -7587,7 +7602,7 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -7610,7 +7625,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [
"libc",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -7714,7 +7729,6 @@ dependencies = [
"cfg-if",
"libc",
"psm",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
@@ -8075,7 +8089,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix 1.1.4",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -8880,7 +8894,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]

View File

@@ -15,20 +15,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=5.0.0-beta.2", default-features = false, "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=5.0.0-beta.2", default-features = false, "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=5.0.0-beta.2", default-features = false, "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance = { version = "=4.0.0", default-features = false }
lance-core = { version = "=4.0.0" }
lance-datagen = { version = "=4.0.0" }
lance-file = { version = "=4.0.0" }
lance-io = { version = "=4.0.0", default-features = false }
lance-index = { version = "=4.0.0" }
lance-linalg = { version = "=4.0.0" }
lance-namespace = { version = "=4.0.0" }
lance-namespace-impls = { version = "=4.0.0", default-features = false }
lance-table = { version = "=4.0.0" }
lance-testing = { version = "=4.0.0" }
lance-datafusion = { version = "=4.0.0" }
lance-encoding = { version = "=4.0.0" }
lance-arrow = { version = "=4.0.0" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false }

View File

@@ -36,6 +36,20 @@ is also an [asynchronous API client](#connections-asynchronous).
::: lancedb.table.Tags
## Expressions
Type-safe expression builder for filters and projections. Use these instead
of raw SQL strings with [where][lancedb.query.LanceQueryBuilder.where] and
[select][lancedb.query.LanceQueryBuilder.select].
::: lancedb.expr.Expr
::: lancedb.expr.col
::: lancedb.expr.lit
::: lancedb.expr.func
## Querying (Synchronous)
::: lancedb.query.Query

View File

@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>5.0.0-beta.2</lance-core.version>
<lance-core.version>3.0.1</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.30.2-beta.1"
current_version = "0.30.2"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

2
python/.gitignore vendored
View File

@@ -1,3 +1,5 @@
# Test data created by some example tests
data/
_lancedb.pyd
# macOS debug symbols bundle generated during build
*.dSYM/

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.30.2-beta.1"
version = "0.30.2"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -18,6 +18,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
from .io import StorageOptionsProvider
from .remote import ClientConfig
from .remote.db import RemoteDBConnection
from .expr import Expr, col, lit, func
from .schema import vector
from .table import AsyncTable, Table
from ._lancedb import Session
@@ -271,6 +272,10 @@ __all__ = [
"AsyncConnection",
"AsyncLanceNamespaceDBConnection",
"AsyncTable",
"col",
"Expr",
"func",
"lit",
"URI",
"sanitize_uri",
"vector",

View File

@@ -27,6 +27,32 @@ from .remote import ClientConfig
IvfHnswPq: type[HnswPq] = HnswPq
IvfHnswSq: type[HnswSq] = HnswSq
class PyExpr:
"""A type-safe DataFusion expression node (Rust-side handle)."""
def eq(self, other: "PyExpr") -> "PyExpr": ...
def ne(self, other: "PyExpr") -> "PyExpr": ...
def lt(self, other: "PyExpr") -> "PyExpr": ...
def lte(self, other: "PyExpr") -> "PyExpr": ...
def gt(self, other: "PyExpr") -> "PyExpr": ...
def gte(self, other: "PyExpr") -> "PyExpr": ...
def and_(self, other: "PyExpr") -> "PyExpr": ...
def or_(self, other: "PyExpr") -> "PyExpr": ...
def not_(self) -> "PyExpr": ...
def add(self, other: "PyExpr") -> "PyExpr": ...
def sub(self, other: "PyExpr") -> "PyExpr": ...
def mul(self, other: "PyExpr") -> "PyExpr": ...
def div(self, other: "PyExpr") -> "PyExpr": ...
def lower(self) -> "PyExpr": ...
def upper(self) -> "PyExpr": ...
def contains(self, substr: "PyExpr") -> "PyExpr": ...
def cast(self, data_type: pa.DataType) -> "PyExpr": ...
def to_sql(self) -> str: ...
def expr_col(name: str) -> PyExpr: ...
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...
class Session:
def __init__(
self,
@@ -225,7 +251,9 @@ class RecordBatchStream:
class Query:
def where(self, filter: str): ...
def select(self, columns: Tuple[str, str]): ...
def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def select_columns(self, columns: List[str]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
@@ -251,7 +279,9 @@ class TakeQuery:
class FTSQuery:
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
def fast_search(self): ...
@@ -270,7 +300,9 @@ class VectorQuery:
async def output_schema(self) -> pa.Schema: ...
async def execute(self) -> RecordBatchStream: ...
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def select_with_projection(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
@@ -287,7 +319,9 @@ class VectorQuery:
class HybridQuery:
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
def where_expr(self, expr: PyExpr): ...
def select(self, columns: List[Tuple[str, str]]): ...
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
def limit(self, limit: int): ...
def offset(self, offset: int): ...
def fast_search(self): ...

View File

@@ -0,0 +1,298 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Type-safe expression builder for filters and projections.
Instead of writing raw SQL strings you can build expressions with Python
operators::
from lancedb.expr import col, lit
# filter: age > 18 AND status = 'active'
filt = (col("age") > lit(18)) & (col("status") == lit("active"))
# projection: compute a derived column
proj = {"score": col("raw_score") * lit(1.5)}
table.search().where(filt).select(proj).to_list()
"""
from __future__ import annotations
from typing import Union
import pyarrow as pa
from lancedb._lancedb import PyExpr, expr_col, expr_lit, expr_func
__all__ = ["Expr", "col", "lit", "func"]
_STR_TO_PA_TYPE: dict = {
"bool": pa.bool_(),
"boolean": pa.bool_(),
"int8": pa.int8(),
"int16": pa.int16(),
"int32": pa.int32(),
"int64": pa.int64(),
"uint8": pa.uint8(),
"uint16": pa.uint16(),
"uint32": pa.uint32(),
"uint64": pa.uint64(),
"float16": pa.float16(),
"float32": pa.float32(),
"float": pa.float32(),
"float64": pa.float64(),
"double": pa.float64(),
"string": pa.string(),
"utf8": pa.string(),
"str": pa.string(),
"large_string": pa.large_utf8(),
"large_utf8": pa.large_utf8(),
"date32": pa.date32(),
"date": pa.date32(),
"date64": pa.date64(),
}
def _coerce(value: "ExprLike") -> "Expr":
"""Return *value* as an :class:`Expr`, wrapping plain Python values via
:func:`lit` if needed."""
if isinstance(value, Expr):
return value
return lit(value)
# Type alias used in annotations.
ExprLike = Union["Expr", bool, int, float, str]
class Expr:
"""A type-safe expression node.
Construct instances with :func:`col` and :func:`lit`, then combine them
using Python operators or the named methods below.
Examples
--------
>>> from lancedb.expr import col, lit
>>> filt = (col("age") > lit(18)) & (col("name").lower() == lit("alice"))
>>> proj = {"double": col("x") * lit(2)}
"""
# Make Expr unhashable so that == returns an Expr rather than being used
# for dict keys / set membership.
__hash__ = None # type: ignore[assignment]
def __init__(self, inner: PyExpr) -> None:
self._inner = inner
# ── comparisons ──────────────────────────────────────────────────────────
def __eq__(self, other: ExprLike) -> "Expr": # type: ignore[override]
"""Equal to (``col("x") == 1``)."""
return Expr(self._inner.eq(_coerce(other)._inner))
def __ne__(self, other: ExprLike) -> "Expr": # type: ignore[override]
"""Not equal to (``col("x") != 1``)."""
return Expr(self._inner.ne(_coerce(other)._inner))
def __lt__(self, other: ExprLike) -> "Expr":
"""Less than (``col("x") < 1``)."""
return Expr(self._inner.lt(_coerce(other)._inner))
def __le__(self, other: ExprLike) -> "Expr":
"""Less than or equal to (``col("x") <= 1``)."""
return Expr(self._inner.lte(_coerce(other)._inner))
def __gt__(self, other: ExprLike) -> "Expr":
"""Greater than (``col("x") > 1``)."""
return Expr(self._inner.gt(_coerce(other)._inner))
def __ge__(self, other: ExprLike) -> "Expr":
"""Greater than or equal to (``col("x") >= 1``)."""
return Expr(self._inner.gte(_coerce(other)._inner))
# ── logical ──────────────────────────────────────────────────────────────
def __and__(self, other: "Expr") -> "Expr":
"""Logical AND (``expr_a & expr_b``)."""
return Expr(self._inner.and_(_coerce(other)._inner))
def __or__(self, other: "Expr") -> "Expr":
"""Logical OR (``expr_a | expr_b``)."""
return Expr(self._inner.or_(_coerce(other)._inner))
def __invert__(self) -> "Expr":
"""Logical NOT (``~expr``)."""
return Expr(self._inner.not_())
# ── arithmetic ───────────────────────────────────────────────────────────
def __add__(self, other: ExprLike) -> "Expr":
"""Add (``col("x") + 1``)."""
return Expr(self._inner.add(_coerce(other)._inner))
def __radd__(self, other: ExprLike) -> "Expr":
"""Right-hand add (``1 + col("x")``)."""
return Expr(_coerce(other)._inner.add(self._inner))
def __sub__(self, other: ExprLike) -> "Expr":
"""Subtract (``col("x") - 1``)."""
return Expr(self._inner.sub(_coerce(other)._inner))
def __rsub__(self, other: ExprLike) -> "Expr":
"""Right-hand subtract (``1 - col("x")``)."""
return Expr(_coerce(other)._inner.sub(self._inner))
def __mul__(self, other: ExprLike) -> "Expr":
"""Multiply (``col("x") * 2``)."""
return Expr(self._inner.mul(_coerce(other)._inner))
def __rmul__(self, other: ExprLike) -> "Expr":
"""Right-hand multiply (``2 * col("x")``)."""
return Expr(_coerce(other)._inner.mul(self._inner))
def __truediv__(self, other: ExprLike) -> "Expr":
"""Divide (``col("x") / 2``)."""
return Expr(self._inner.div(_coerce(other)._inner))
def __rtruediv__(self, other: ExprLike) -> "Expr":
"""Right-hand divide (``1 / col("x")``)."""
return Expr(_coerce(other)._inner.div(self._inner))
# ── string methods ───────────────────────────────────────────────────────
def lower(self) -> "Expr":
"""Convert string column values to lowercase."""
return Expr(self._inner.lower())
def upper(self) -> "Expr":
"""Convert string column values to uppercase."""
return Expr(self._inner.upper())
def contains(self, substr: "ExprLike") -> "Expr":
"""Return True where the string contains *substr*."""
return Expr(self._inner.contains(_coerce(substr)._inner))
# ── type cast ────────────────────────────────────────────────────────────
def cast(self, data_type: Union[str, "pa.DataType"]) -> "Expr":
"""Cast values to *data_type*.
Parameters
----------
data_type:
A PyArrow ``DataType`` (e.g. ``pa.int32()``) or one of the type
name strings: ``"bool"``, ``"int8"``, ``"int16"``, ``"int32"``,
``"int64"``, ``"uint8"````"uint64"``, ``"float32"``,
``"float64"``, ``"string"``, ``"date32"``, ``"date64"``.
"""
if isinstance(data_type, str):
try:
data_type = _STR_TO_PA_TYPE[data_type]
except KeyError:
raise ValueError(
f"unsupported data type: '{data_type}'. Supported: "
f"{', '.join(_STR_TO_PA_TYPE)}"
)
return Expr(self._inner.cast(data_type))
# ── named comparison helpers (alternative to operators) ──────────────────
def eq(self, other: ExprLike) -> "Expr":
"""Equal to."""
return self.__eq__(other)
def ne(self, other: ExprLike) -> "Expr":
"""Not equal to."""
return self.__ne__(other)
def lt(self, other: ExprLike) -> "Expr":
"""Less than."""
return self.__lt__(other)
def lte(self, other: ExprLike) -> "Expr":
"""Less than or equal to."""
return self.__le__(other)
def gt(self, other: ExprLike) -> "Expr":
"""Greater than."""
return self.__gt__(other)
def gte(self, other: ExprLike) -> "Expr":
"""Greater than or equal to."""
return self.__ge__(other)
def and_(self, other: "Expr") -> "Expr":
"""Logical AND."""
return self.__and__(other)
def or_(self, other: "Expr") -> "Expr":
"""Logical OR."""
return self.__or__(other)
# ── utilities ────────────────────────────────────────────────────────────
def to_sql(self) -> str:
"""Render the expression as a SQL string (useful for debugging)."""
return self._inner.to_sql()
def __repr__(self) -> str:
return f"Expr({self._inner.to_sql()})"
# ── free functions ────────────────────────────────────────────────────────────
def col(name: str) -> Expr:
"""Reference a table column by name.
Parameters
----------
name:
The column name.
Examples
--------
>>> from lancedb.expr import col, lit
>>> col("age") > lit(18)
Expr((age > 18))
"""
return Expr(expr_col(name))
def lit(value: Union[bool, int, float, str]) -> Expr:
"""Create a literal (constant) value expression.
Parameters
----------
value:
A Python ``bool``, ``int``, ``float``, or ``str``.
Examples
--------
>>> from lancedb.expr import col, lit
>>> col("price") * lit(1.1)
Expr((price * 1.1))
"""
return Expr(expr_lit(value))
def func(name: str, *args: ExprLike) -> Expr:
"""Call an arbitrary SQL function by name.
Parameters
----------
name:
The SQL function name (e.g. ``"lower"``, ``"upper"``).
*args:
The function arguments as :class:`Expr` or plain Python literals.
Examples
--------
>>> from lancedb.expr import col, func
>>> func("lower", col("name"))
Expr(lower(name))
"""
inner_args = [_coerce(a)._inner for a in args]
return Expr(expr_func(name, inner_args))

View File

@@ -38,6 +38,7 @@ from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import flatten_columns
from .expr import Expr
from lancedb._lancedb import fts_query_to_json
from typing_extensions import Annotated
@@ -449,8 +450,8 @@ class Query(pydantic.BaseModel):
ensure_vector_query,
] = None
# sql filter to refine the query with
filter: Optional[str] = None
# sql filter or type-safe Expr to refine the query with
filter: Optional[Union[str, Expr]] = None
# if True then apply the filter after vector search
postfilter: Optional[bool] = None
@@ -464,8 +465,8 @@ class Query(pydantic.BaseModel):
# distance type to use for vector search
distance_type: Optional[str] = None
# which columns to return in the results
columns: Optional[Union[List[str], Dict[str, str]]] = None
# which columns to return in the results (dict values may be str or Expr)
columns: Optional[Union[List[str], Dict[str, Union[str, Expr]]]] = None
# minimum number of IVF partitions to search
#
@@ -856,14 +857,15 @@ class LanceQueryBuilder(ABC):
self._offset = offset
return self
def select(self, columns: Union[list[str], dict[str, str]]) -> Self:
def select(self, columns: Union[list[str], dict[str, Union[str, Expr]]]) -> Self:
"""Set the columns to return.
Parameters
----------
columns: list of str, or dict of str to str default None
columns: list of str, or dict of str to str or Expr
List of column names to be fetched.
Or a dictionary of column names to SQL expressions.
Or a dictionary of column names to SQL expressions or
:class:`~lancedb.expr.Expr` objects.
All columns are fetched if None or unspecified.
Returns
@@ -877,15 +879,15 @@ class LanceQueryBuilder(ABC):
raise ValueError("columns must be a list or a dictionary")
return self
def where(self, where: str, prefilter: bool = True) -> Self:
def where(self, where: Union[str, Expr], prefilter: bool = True) -> Self:
"""Set the where clause.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
for valid SQL expressions.
where: str or :class:`~lancedb.expr.Expr`
The filter condition. Can be a SQL string or a type-safe
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
and :func:`~lancedb.expr.lit`.
prefilter: bool, default True
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
@@ -1355,15 +1357,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
return result_set
def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder:
def where(
self, where: Union[str, Expr], prefilter: bool = None
) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
for valid SQL expressions.
where: str or :class:`~lancedb.expr.Expr`
The filter condition. Can be a SQL string or a type-safe
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
and :func:`~lancedb.expr.lit`.
prefilter: bool, default True
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
@@ -2286,10 +2290,20 @@ class AsyncQueryBase(object):
"""
if isinstance(columns, list) and all(isinstance(c, str) for c in columns):
self._inner.select_columns(columns)
elif isinstance(columns, dict) and all(
isinstance(k, str) and isinstance(v, str) for k, v in columns.items()
):
self._inner.select(list(columns.items()))
elif isinstance(columns, dict) and all(isinstance(k, str) for k in columns):
if any(isinstance(v, Expr) for v in columns.values()):
# At least one value is an Expr — use the type-safe path.
from .expr import _coerce
pairs = [(k, _coerce(v)._inner) for k, v in columns.items()]
self._inner.select_expr(pairs)
elif all(isinstance(v, str) for v in columns.values()):
self._inner.select(list(columns.items()))
else:
raise TypeError(
"dict values must be str or Expr, got "
+ str({k: type(v) for k, v in columns.items()})
)
else:
raise TypeError("columns must be a list of column names or a dict")
return self
@@ -2529,11 +2543,13 @@ class AsyncStandardQuery(AsyncQueryBase):
"""
super().__init__(inner)
def where(self, predicate: str) -> Self:
def where(self, predicate: Union[str, Expr]) -> Self:
"""
Only return rows matching the given predicate
The predicate should be supplied as an SQL query string.
The predicate can be a SQL string or a type-safe
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
and :func:`~lancedb.expr.lit`.
Examples
--------
@@ -2545,7 +2561,10 @@ class AsyncStandardQuery(AsyncQueryBase):
Filtering performance can often be improved by creating a scalar index
on the filter column(s).
"""
self._inner.where(predicate)
if isinstance(predicate, Expr):
self._inner.where_expr(predicate._inner)
else:
self._inner.where(predicate)
return self
def limit(self, limit: int) -> Self:

View File

@@ -568,4 +568,4 @@ class RemoteDBConnection(DBConnection):
async def close(self):
"""Close the connection to the database."""
self._client.close()
self._conn.close()

View File

@@ -4211,7 +4211,7 @@ class AsyncTable:
async_query = async_query.offset(query.offset)
if query.columns:
async_query = async_query.select(query.columns)
if query.filter:
if query.filter is not None:
async_query = async_query.where(query.filter)
if query.fast_search:
async_query = async_query.fast_search()

View File

@@ -559,7 +559,8 @@ def test_url_retrieve_downloads_image():
matching the real usage pattern in embedding functions.
"""
import io
from PIL import Image
Image = pytest.importorskip("PIL.Image")
from lancedb.embeddings.utils import url_retrieve
image_url = "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"

View File

@@ -1201,6 +1201,18 @@ async def test_header_provider_overrides_static_headers():
await db.table_names()
def test_close():
"""Test that close() works without AttributeError."""
import asyncio
def handler(req):
req.send_response(200)
req.end_headers()
with mock_lancedb_connection(handler) as db:
asyncio.run(db.close())
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
def test_background_loop_cancellation(exception):
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""

175
python/src/expr.rs Normal file
View File

@@ -0,0 +1,175 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! PyO3 bindings for the LanceDB expression builder API.
//!
//! This module exposes [`PyExpr`] and helper free functions so Python can
//! build type-safe filter / projection expressions that map directly to
//! DataFusion [`Expr`] nodes, bypassing SQL string parsing.
use arrow::{datatypes::DataType, pyarrow::PyArrowType};
use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper};
use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction};
/// A type-safe DataFusion expression.
///
/// Instances are constructed via the free functions [`expr_col`] and
/// [`expr_lit`] and combined with the methods on this struct. On the Python
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
/// and adds Python operator overloads.
#[pyclass(name = "PyExpr")]
#[derive(Clone)]
pub struct PyExpr(pub DfExpr);
#[pymethods]
impl PyExpr {
// ── comparisons ──────────────────────────────────────────────────────────
fn eq(&self, other: &Self) -> Self {
Self(self.0.clone().eq(other.0.clone()))
}
fn ne(&self, other: &Self) -> Self {
Self(self.0.clone().not_eq(other.0.clone()))
}
fn lt(&self, other: &Self) -> Self {
Self(self.0.clone().lt(other.0.clone()))
}
fn lte(&self, other: &Self) -> Self {
Self(self.0.clone().lt_eq(other.0.clone()))
}
fn gt(&self, other: &Self) -> Self {
Self(self.0.clone().gt(other.0.clone()))
}
fn gte(&self, other: &Self) -> Self {
Self(self.0.clone().gt_eq(other.0.clone()))
}
// ── logical ──────────────────────────────────────────────────────────────
fn and_(&self, other: &Self) -> Self {
Self(self.0.clone().and(other.0.clone()))
}
fn or_(&self, other: &Self) -> Self {
Self(self.0.clone().or(other.0.clone()))
}
fn not_(&self) -> Self {
use std::ops::Not;
Self(self.0.clone().not())
}
// ── arithmetic ───────────────────────────────────────────────────────────
fn add(&self, other: &Self) -> Self {
use std::ops::Add;
Self(self.0.clone().add(other.0.clone()))
}
fn sub(&self, other: &Self) -> Self {
use std::ops::Sub;
Self(self.0.clone().sub(other.0.clone()))
}
fn mul(&self, other: &Self) -> Self {
use std::ops::Mul;
Self(self.0.clone().mul(other.0.clone()))
}
fn div(&self, other: &Self) -> Self {
use std::ops::Div;
Self(self.0.clone().div(other.0.clone()))
}
// ── string functions ─────────────────────────────────────────────────────
/// Convert string column to lowercase.
fn lower(&self) -> Self {
Self(lower(self.0.clone()))
}
/// Convert string column to uppercase.
fn upper(&self) -> Self {
Self(upper(self.0.clone()))
}
/// Test whether the string contains `substr`.
fn contains(&self, substr: &Self) -> Self {
Self(contains(self.0.clone(), substr.0.clone()))
}
// ── type cast ────────────────────────────────────────────────────────────
/// Cast the expression to `data_type`.
///
/// `data_type` must be a PyArrow `DataType` (e.g. `pa.int32()`).
/// On the Python side, `lancedb.expr.Expr.cast` also accepts type name
/// strings via `pa.lib.ensure_type` before forwarding here.
fn cast(&self, data_type: PyArrowType<DataType>) -> Self {
Self(expr_cast(self.0.clone(), data_type.0))
}
// ── utilities ────────────────────────────────────────────────────────────
/// Render the expression as a SQL string (useful for debugging).
fn to_sql(&self) -> PyResult<String> {
lancedb::expr::expr_to_sql_string(&self.0).map_err(|e| PyValueError::new_err(e.to_string()))
}
fn __repr__(&self) -> PyResult<String> {
let sql =
lancedb::expr::expr_to_sql_string(&self.0).unwrap_or_else(|_| "<expr>".to_string());
Ok(format!("PyExpr({})", sql))
}
}
// ── free functions ────────────────────────────────────────────────────────────
/// Create a column reference expression.
///
/// The column name is preserved exactly as given (case-sensitive), so
/// `col("firstName")` correctly references a field named `firstName`.
#[pyfunction]
pub fn expr_col(name: &str) -> PyExpr {
PyExpr(ldb_col(name))
}
/// Create a literal value expression.
///
/// Supported Python types: `bool`, `int`, `float`, `str`.
#[pyfunction]
pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
// bool must be checked before int because bool is a subclass of int in Python
if let Ok(b) = value.extract::<bool>() {
return Ok(PyExpr(df_lit(b)));
}
if let Ok(i) = value.extract::<i64>() {
return Ok(PyExpr(df_lit(i)));
}
if let Ok(f) = value.extract::<f64>() {
return Ok(PyExpr(df_lit(f)));
}
if let Ok(s) = value.extract::<String>() {
return Ok(PyExpr(df_lit(s)));
}
Err(PyValueError::new_err(format!(
"unsupported literal type: {}. Supported: bool, int, float, str",
value.get_type().name()?
)))
}
/// Call an arbitrary registered SQL function by name.
///
/// See `lancedb::expr::func` for the list of supported function names.
#[pyfunction]
pub fn expr_func(name: &str, args: Vec<PyExpr>) -> PyResult<PyExpr> {
let df_args: Vec<DfExpr> = args.into_iter().map(|e| e.0).collect();
lancedb::expr::func(name, df_args)
.map(PyExpr)
.map_err(|e| PyValueError::new_err(e.to_string()))
}

View File

@@ -4,6 +4,7 @@
use arrow::RecordBatchStream;
use connection::{Connection, connect};
use env_logger::Env;
use expr::{PyExpr, expr_col, expr_func, expr_lit};
use index::IndexConfig;
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
use pyo3::{
@@ -21,6 +22,7 @@ use table::{
pub mod arrow;
pub mod connection;
pub mod error;
pub mod expr;
pub mod header;
pub mod index;
pub mod namespace;
@@ -55,10 +57,14 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_class::<PyPermutationReader>()?;
m.add_class::<PyExpr>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
m.add_function(wrap_pyfunction!(expr_col, m)?)?;
m.add_function(wrap_pyfunction!(expr_lit, m)?)?;
m.add_function(wrap_pyfunction!(expr_func, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())
}

View File

@@ -35,12 +35,10 @@ use pyo3::types::PyList;
use pyo3::types::{PyDict, PyString};
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
use pyo3::{PyErr, pyclass};
use pyo3::{
exceptions::{PyNotImplementedError, PyValueError},
intern,
};
use pyo3::{exceptions::PyValueError, intern};
use pyo3_async_runtimes::tokio::future_into_py;
use crate::expr::PyExpr;
use crate::util::parse_distance_type;
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
use crate::{error::PythonErrorExt, index::class_name};
@@ -344,9 +342,13 @@ impl<'py> IntoPyObject<'py> for PyQueryFilter {
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
match self.0 {
QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err(
"Datafusion filter has no conversion to Python",
)),
QueryFilter::Datafusion(expr) => {
// Serialize the DataFusion expression to a SQL string so that
// callers (e.g. remote tables) see the same format as Sql.
let sql = lancedb::expr::expr_to_sql_string(&expr)
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
Ok(sql.into_pyobject(py)?.into_any())
}
QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
}
@@ -370,10 +372,20 @@ impl Query {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner = self.inner.clone().only_if_expr(expr.0);
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
self.inner = self.inner.clone().select(Select::Expr(pairs));
}
pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(Select::columns(&columns));
}
@@ -607,10 +619,20 @@ impl FTSQuery {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner = self.inner.clone().only_if_expr(expr.0);
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
self.inner = self.inner.clone().select(Select::Expr(pairs));
}
pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(Select::columns(&columns));
}
@@ -725,6 +747,10 @@ impl VectorQuery {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner = self.inner.clone().only_if_expr(expr.0);
}
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data);
@@ -736,6 +762,12 @@ impl VectorQuery {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
self.inner = self.inner.clone().select(Select::Expr(pairs));
}
pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(Select::columns(&columns));
}
@@ -890,11 +922,21 @@ impl HybridQuery {
self.inner_fts.r#where(predicate);
}
pub fn where_expr(&mut self, expr: PyExpr) {
self.inner_vec.where_expr(expr.clone());
self.inner_fts.where_expr(expr);
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner_vec.select(columns.clone());
self.inner_fts.select(columns);
}
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
self.inner_vec.select_expr(columns.clone());
self.inner_fts.select_expr(columns);
}
pub fn select_columns(&mut self, columns: Vec<String>) {
self.inner_vec.select_columns(columns.clone());
self.inner_fts.select_columns(columns);

387
python/tests/test_expr.py Normal file
View File

@@ -0,0 +1,387 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Tests for the type-safe expression builder API."""
import pytest
import pyarrow as pa
import lancedb
from lancedb.expr import Expr, col, lit, func
# ── unit tests for Expr construction ─────────────────────────────────────────
class TestExprConstruction:
def test_col_returns_expr(self):
e = col("age")
assert isinstance(e, Expr)
def test_lit_int(self):
e = lit(42)
assert isinstance(e, Expr)
def test_lit_float(self):
e = lit(3.14)
assert isinstance(e, Expr)
def test_lit_str(self):
e = lit("hello")
assert isinstance(e, Expr)
def test_lit_bool(self):
e = lit(True)
assert isinstance(e, Expr)
def test_lit_unsupported_type_raises(self):
with pytest.raises(Exception):
lit([1, 2, 3])
def test_func(self):
e = func("lower", col("name"))
assert isinstance(e, Expr)
assert e.to_sql() == "lower(name)"
def test_func_unknown_raises(self):
with pytest.raises(Exception):
func("not_a_real_function", col("x"))
class TestExprOperators:
def test_eq_operator(self):
e = col("x") == lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x = 1)"
def test_ne_operator(self):
e = col("x") != lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x <> 1)"
def test_lt_operator(self):
e = col("age") < lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age < 18)"
def test_le_operator(self):
e = col("age") <= lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age <= 18)"
def test_gt_operator(self):
e = col("age") > lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age > 18)"
def test_ge_operator(self):
e = col("age") >= lit(18)
assert isinstance(e, Expr)
assert e.to_sql() == "(age >= 18)"
def test_and_operator(self):
e = (col("age") > lit(18)) & (col("status") == lit("active"))
assert isinstance(e, Expr)
assert e.to_sql() == "((age > 18) AND (status = 'active'))"
def test_or_operator(self):
e = (col("a") == lit(1)) | (col("b") == lit(2))
assert isinstance(e, Expr)
assert e.to_sql() == "((a = 1) OR (b = 2))"
def test_invert_operator(self):
e = ~(col("active") == lit(True))
assert isinstance(e, Expr)
assert e.to_sql() == "NOT (active = true)"
def test_add_operator(self):
e = col("x") + lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x + 1)"
def test_sub_operator(self):
e = col("x") - lit(1)
assert isinstance(e, Expr)
assert e.to_sql() == "(x - 1)"
def test_mul_operator(self):
e = col("price") * lit(1.1)
assert isinstance(e, Expr)
assert e.to_sql() == "(price * 1.1)"
def test_div_operator(self):
e = col("total") / lit(2)
assert isinstance(e, Expr)
assert e.to_sql() == "(total / 2)"
def test_radd(self):
e = lit(1) + col("x")
assert isinstance(e, Expr)
assert e.to_sql() == "(1 + x)"
def test_rmul(self):
e = lit(2) * col("x")
assert isinstance(e, Expr)
assert e.to_sql() == "(2 * x)"
def test_coerce_plain_int(self):
# Operators should auto-wrap plain Python values via lit()
e = col("age") > 18
assert isinstance(e, Expr)
assert e.to_sql() == "(age > 18)"
def test_coerce_plain_str(self):
e = col("name") == "alice"
assert isinstance(e, Expr)
assert e.to_sql() == "(name = 'alice')"
class TestExprStringMethods:
def test_lower(self):
e = col("name").lower()
assert isinstance(e, Expr)
assert e.to_sql() == "lower(name)"
def test_upper(self):
e = col("name").upper()
assert isinstance(e, Expr)
assert e.to_sql() == "upper(name)"
def test_contains(self):
e = col("text").contains(lit("hello"))
assert isinstance(e, Expr)
assert e.to_sql() == "contains(text, 'hello')"
def test_contains_with_str_coerce(self):
e = col("text").contains("hello")
assert isinstance(e, Expr)
assert e.to_sql() == "contains(text, 'hello')"
def test_chained_lower_eq(self):
e = col("name").lower() == lit("alice")
assert isinstance(e, Expr)
assert e.to_sql() == "(lower(name) = 'alice')"
class TestExprCast:
def test_cast_string(self):
e = col("id").cast("string")
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(id AS VARCHAR)"
def test_cast_int32(self):
e = col("score").cast("int32")
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(score AS INTEGER)"
def test_cast_float64(self):
e = col("val").cast("float64")
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(val AS DOUBLE)"
def test_cast_pyarrow_type(self):
e = col("score").cast(pa.int32())
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(score AS INTEGER)"
def test_cast_pyarrow_float64(self):
e = col("val").cast(pa.float64())
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(val AS DOUBLE)"
def test_cast_pyarrow_string(self):
e = col("id").cast(pa.string())
assert isinstance(e, Expr)
assert e.to_sql() == "CAST(id AS VARCHAR)"
def test_cast_pyarrow_and_string_equivalent(self):
# pa.int32() and "int32" should produce equivalent SQL
sql_str = col("x").cast("int32").to_sql()
sql_pa = col("x").cast(pa.int32()).to_sql()
assert sql_str == sql_pa
class TestExprNamedMethods:
def test_eq_method(self):
e = col("x").eq(lit(1))
assert isinstance(e, Expr)
assert e.to_sql() == "(x = 1)"
def test_gt_method(self):
e = col("x").gt(lit(0))
assert isinstance(e, Expr)
assert e.to_sql() == "(x > 0)"
def test_and_method(self):
e = col("x").gt(lit(0)).and_(col("y").lt(lit(10)))
assert isinstance(e, Expr)
assert e.to_sql() == "((x > 0) AND (y < 10))"
def test_or_method(self):
e = col("x").eq(lit(1)).or_(col("x").eq(lit(2)))
assert isinstance(e, Expr)
assert e.to_sql() == "((x = 1) OR (x = 2))"
class TestExprRepr:
def test_repr(self):
e = col("age") > lit(18)
assert repr(e) == "Expr((age > 18))"
def test_to_sql(self):
e = col("age") > 18
assert e.to_sql() == "(age > 18)"
def test_unhashable(self):
e = col("x")
with pytest.raises(TypeError):
{e: 1}
# ── integration tests: end-to-end query against a real table ─────────────────
@pytest.fixture
def simple_table(tmp_path):
db = lancedb.connect(str(tmp_path))
data = pa.table(
{
"id": [1, 2, 3, 4, 5],
"name": ["Alice", "Bob", "Charlie", "alice", "BOB"],
"age": [25, 17, 30, 22, 15],
"score": [1.5, 2.0, 3.5, 4.0, 0.5],
}
)
return db.create_table("test", data)
class TestExprFilter:
def test_simple_gt_filter(self, simple_table):
result = simple_table.search().where(col("age") > lit(20)).to_arrow()
assert result.num_rows == 3 # ages 25, 30, 22
def test_compound_and_filter(self, simple_table):
result = (
simple_table.search()
.where((col("age") > lit(18)) & (col("score") > lit(2.0)))
.to_arrow()
)
assert result.num_rows == 2 # (30, 3.5) and (22, 4.0)
def test_string_equality_filter(self, simple_table):
result = simple_table.search().where(col("name") == lit("Bob")).to_arrow()
assert result.num_rows == 1
def test_or_filter(self, simple_table):
result = (
simple_table.search()
.where((col("age") < lit(18)) | (col("age") > lit(28)))
.to_arrow()
)
assert result.num_rows == 3 # ages 17, 30, 15
def test_coercion_no_lit(self, simple_table):
# Python values should be auto-coerced
result = simple_table.search().where(col("age") > 20).to_arrow()
assert result.num_rows == 3
def test_string_sql_still_works(self, simple_table):
# Backwards compatibility: plain strings still accepted
result = simple_table.search().where("age > 20").to_arrow()
assert result.num_rows == 3
class TestExprProjection:
def test_select_with_expr(self, simple_table):
result = (
simple_table.search()
.select({"double_score": col("score") * lit(2)})
.to_arrow()
)
assert "double_score" in result.schema.names
def test_select_mixed_str_and_expr(self, simple_table):
result = (
simple_table.search()
.select({"id": "id", "double_score": col("score") * lit(2)})
.to_arrow()
)
assert "id" in result.schema.names
assert "double_score" in result.schema.names
def test_select_list_of_columns(self, simple_table):
# Plain list of str still works
result = simple_table.search().select(["id", "name"]).to_arrow()
assert result.schema.names == ["id", "name"]
# ── column name edge cases ────────────────────────────────────────────────────
class TestColNaming:
"""Unit tests verifying that col() preserves identifiers exactly.
Identifiers that need quoting (camelCase, spaces, leading digits, unicode)
are wrapped in backticks to match the lance SQL parser's dialect.
"""
def test_camel_case_preserved_in_sql(self):
# camelCase is quoted with backticks so the case round-trips correctly.
assert col("firstName").to_sql() == "`firstName`"
def test_camel_case_in_expression(self):
assert (col("firstName") > lit(18)).to_sql() == "(`firstName` > 18)"
def test_space_in_name_quoted(self):
assert col("first name").to_sql() == "`first name`"
def test_space_in_expression(self):
assert (col("first name") == lit("A")).to_sql() == "(`first name` = 'A')"
def test_leading_digit_quoted(self):
assert col("2fast").to_sql() == "`2fast`"
def test_unicode_quoted(self):
assert col("名前").to_sql() == "`名前`"
def test_snake_case_unquoted(self):
# Plain snake_case needs no quoting.
assert col("first_name").to_sql() == "first_name"
@pytest.fixture
def special_col_table(tmp_path):
db = lancedb.connect(str(tmp_path))
data = pa.table(
{
"firstName": ["Alice", "Bob", "Charlie"],
"first name": ["A", "B", "C"],
"score": [10, 20, 30],
}
)
return db.create_table("special", data)
class TestColNamingIntegration:
def test_camel_case_filter(self, special_col_table):
result = (
special_col_table.search()
.where(col("firstName") == lit("Alice"))
.to_arrow()
)
assert result.num_rows == 1
assert result["firstName"][0].as_py() == "Alice"
def test_space_in_col_filter(self, special_col_table):
result = (
special_col_table.search().where(col("first name") == lit("B")).to_arrow()
)
assert result.num_rows == 1
def test_camel_case_projection(self, special_col_table):
result = (
special_col_table.search()
.select({"upper_name": col("firstName").upper()})
.to_arrow()
)
assert "upper_name" in result.schema.names
assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"]

View File

@@ -27,7 +27,17 @@ use arrow_schema::DataType;
use datafusion_expr::{Expr, ScalarUDF, expr_fn::cast};
use datafusion_functions::string::expr_fn as string_expr_fn;
pub use datafusion_expr::{col, lit};
pub use datafusion_expr::lit;
/// Create a column reference expression, preserving the name exactly as given.
///
/// Unlike DataFusion's built-in [`col`][datafusion_expr::col], this function
/// does **not** normalise the identifier to lower-case, so
/// `col("firstName")` correctly references a field named `firstName`.
pub fn col(name: impl Into<String>) -> DfExpr {
use datafusion_common::Column;
DfExpr::Column(Column::new_unqualified(name))
}
pub use datafusion_expr::Expr as DfExpr;

View File

@@ -2,11 +2,37 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use datafusion_expr::Expr;
use datafusion_sql::unparser;
use datafusion_sql::unparser::{self, dialect::Dialect};
/// Unparser dialect that matches the quoting style expected by the Lance SQL
/// parser. Lance uses backtick (`` ` ``) as the only delimited-identifier
/// quote character, so we must produce `` `firstName` `` rather than
/// `"firstName"` for identifiers that require quoting.
///
/// We quote an identifier when it:
/// * is a SQL reserved word, OR
/// * contains characters outside `[a-zA-Z0-9_]`, OR
/// * starts with a digit, OR
/// * contains upper-case letters (unquoted identifiers are normalised to
/// lower-case by the SQL parser, which would break case-sensitive schemas).
struct LanceSqlDialect;
impl Dialect for LanceSqlDialect {
fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
let needs_quote = identifier.chars().any(|c| c.is_ascii_uppercase())
|| !identifier
.chars()
.enumerate()
.all(|(i, c)| c == '_' || c.is_ascii_alphabetic() || (i > 0 && c.is_ascii_digit()));
if needs_quote { Some('`') } else { None }
}
}
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
let ast = unparser::expr_to_sql(expr).map_err(|e| crate::Error::InvalidInput {
message: format!("failed to serialize expression to SQL: {}", e),
})?;
let ast = unparser::Unparser::new(&LanceSqlDialect)
.expr_to_sql(expr)
.map_err(|e| crate::Error::InvalidInput {
message: format!("failed to serialize expression to SQL: {}", e),
})?;
Ok(ast.to_string())
}

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use std::{future::Future, time::Duration};
use arrow::compute::concat_batches;
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
use arrow_schema::{DataType, SchemaRef};
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
@@ -17,15 +17,17 @@ use lance_datafusion::exec::execute_plan;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::DistanceType;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
use crate::table::BaseTable;
use crate::utils::TimeoutStream;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
table::AnyQuery,
};
mod hybrid;
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
}
}
impl QueryExecutionOptions {
fn without_output_batch_length_limit(&self) -> Self {
let mut options = self.clone();
options.max_batch_length = 0;
options
}
}
/// A trait for a query object that can be executed to get results
///
/// There are various kinds of queries but they all return results
@@ -1180,6 +1190,8 @@ impl VectorQuery {
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let max_batch_length = options.max_batch_length as usize;
let internal_options = options.without_output_batch_length_limit();
// clone query and specify we want to include row IDs, which can be needed for reranking
let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone();
@@ -1189,8 +1201,8 @@ impl VectorQuery {
vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(
fts_query.execute_with_options(options.clone()),
vector_query.inner_execute_with_options(options)
fts_query.execute_with_options(internal_options.clone()),
vector_query.inner_execute_with_options(internal_options)
)?;
let (fts_results, vec_results) = try_join!(
@@ -1245,9 +1257,7 @@ impl VectorQuery {
results = results.drop_column(ROW_ID)?;
}
Ok(SendableRecordBatchStream::from(
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
Ok(single_batch_stream(results, max_batch_length))
}
async fn inner_execute_with_options(
@@ -1256,6 +1266,7 @@ impl VectorQuery {
) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan(options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
@@ -1265,6 +1276,25 @@ impl VectorQuery {
}
}
fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
let schema = batch.schema();
if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
return Box::pin(SimpleRecordBatchStream::new(
stream::iter([Ok(batch)]),
schema,
));
}
let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
let mut offset = 0;
while offset < batch.num_rows() {
let length = (batch.num_rows() - offset).min(max_batch_length);
batches.push(Ok(batch.slice(offset, length)));
offset += length;
}
Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
}
impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
let query = AnyQuery::VectorQuery(self.request.clone());
@@ -1753,6 +1783,50 @@ mod tests {
.unwrap()
}
async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
let dataset_path = tmp_dir.path().join("large_test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
4,
),
false,
),
]));
let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
4,
);
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
let conn = connect(uri).execute().await.unwrap();
conn.create_table("my_table", vec![batch])
.execute()
.await
.unwrap()
}
async fn assert_stream_batches_at_most(
mut results: SendableRecordBatchStream,
max_batch_length: usize,
) {
let mut saw_batch = false;
while let Some(batch) = results.next().await {
let batch = batch.unwrap();
saw_batch = true;
assert!(batch.num_rows() <= max_batch_length);
}
assert!(saw_batch);
}
#[tokio::test]
async fn test_execute_with_options() {
let tmp_dir = tempdir().unwrap();
@@ -1772,6 +1846,83 @@ mod tests {
}
}
#[tokio::test]
async fn test_vector_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let table = make_large_vector_table(&tmp_dir, 10_000).await;
let results = table
.query()
.nearest_to(vec![0.0, 1.0, 2.0, 3.0])
.unwrap()
.limit(10_000)
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test]
async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let rows = 512;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
dims,
);
let record_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
let table = conn
.create_table("my_table", record_batch)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let results = table
.query()
.full_text_search(FullTextSearchQuery::new("match".to_string()))
.limit(rows)
.nearest_to(&[0.0, 0.0])
.unwrap()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test]
async fn test_analyze_plan() {
let tmp_dir = tempdir().unwrap();

View File

@@ -19,11 +19,11 @@ pub use lance::dataset::Version;
use lance::dataset::WriteMode;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::{InsertBuilder, WriteParams};
use lance::index::DatasetIndexExt;
use lance::index::vector::VectorIndexParams;
use lance::index::vector::utils::infer_vector_dim;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance_datafusion::utils::StreamingWriteSource;
use lance_index::DatasetIndexExt;
use lance_index::IndexType;
use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams};
use lance_index::vector::bq::RQBuildParams;

View File

@@ -10,7 +10,7 @@ use std::sync::Arc;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{CompactionMetrics, IndexRemapperOptions, compact_files};
use lance::index::DatasetIndexExt;
use lance_index::DatasetIndexExt;
use lance_index::optimize::OptimizeOptions;
use log::info;

View File

@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
use crate::query::{
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
};
use crate::utils::{TimeoutStream, default_vector_column};
use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::Array;
@@ -66,6 +66,7 @@ async fn execute_generic_query(
) -> Result<DatasetRecordBatchStream> {
let plan = create_plan(table, query, options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
@@ -200,7 +201,9 @@ pub async fn create_plan(
scanner.with_row_id();
}
scanner.batch_size(options.max_batch_length as usize);
if options.max_batch_length > 0 {
scanner.batch_size(options.max_batch_length as usize);
}
if query.base.fast_search {
scanner.fast_search();

View File

@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
}
}
/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
pub struct MaxBatchLengthStream {
inner: SendableRecordBatchStream,
max_batch_length: Option<usize>,
buffered_batch: Option<RecordBatch>,
buffered_offset: usize,
}
impl MaxBatchLengthStream {
pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
Self {
inner,
max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
buffered_batch: None,
buffered_offset: 0,
}
}
pub fn new_boxed(
inner: SendableRecordBatchStream,
max_batch_length: usize,
) -> SendableRecordBatchStream {
if max_batch_length == 0 {
inner
} else {
Box::pin(Self::new(inner, max_batch_length))
}
}
}
impl RecordBatchStream for MaxBatchLengthStream {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}
impl Stream for MaxBatchLengthStream {
type Item = DataFusionResult<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
loop {
let Some(max_batch_length) = self.max_batch_length else {
return Pin::new(&mut self.inner).poll_next(cx);
};
if let Some(batch) = self.buffered_batch.clone() {
if self.buffered_offset < batch.num_rows() {
let remaining = batch.num_rows() - self.buffered_offset;
let length = remaining.min(max_batch_length);
let sliced = batch.slice(self.buffered_offset, length);
self.buffered_offset += length;
if self.buffered_offset >= batch.num_rows() {
self.buffered_batch = None;
self.buffered_offset = 0;
}
return std::task::Poll::Ready(Some(Ok(sliced)));
}
self.buffered_batch = None;
self.buffered_offset = 0;
}
match Pin::new(&mut self.inner).poll_next(cx) {
std::task::Poll::Ready(Some(Ok(batch))) => {
if batch.num_rows() <= max_batch_length {
return std::task::Poll::Ready(Some(Ok(batch)));
}
self.buffered_batch = Some(batch);
self.buffered_offset = 0;
}
other => return other,
}
}
}
}
#[cfg(test)]
mod tests {
use arrow_array::Int32Array;
@@ -470,7 +549,7 @@ mod tests {
assert_eq!(string_to_datatype(string), Some(expected));
}
fn sample_batch() -> RecordBatch {
fn sample_batch(num_rows: i32) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"col1",
DataType::Int32,
@@ -478,14 +557,14 @@ mod tests {
)]));
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
)
.unwrap()
}
#[tokio::test]
async fn test_timeout_stream() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -515,7 +594,7 @@ mod tests {
#[tokio::test]
async fn test_timeout_stream_zero_duration() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -534,7 +613,7 @@ mod tests {
#[tokio::test]
async fn test_timeout_stream_completes_normally() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -552,4 +631,35 @@ mod tests {
// Stream should be empty now
assert!(timeout_stream.next().await.is_none());
}
async fn collect_batch_sizes(
stream: SendableRecordBatchStream,
max_batch_length: usize,
) -> Vec<usize> {
let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
sliced_stream
.by_ref()
.map(|batch| batch.unwrap().num_rows())
.collect::<Vec<_>>()
.await
}
#[tokio::test]
async fn test_max_batch_length_stream_behaviors() {
let schema = sample_batch(7).schema();
let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
assert_eq!(
collect_batch_sizes(sendable_stream, 3).await,
vec![2, 3, 3, 1]
);
let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
schema,
stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
));
assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
}
}