Compare commits

..

3 Commits

Author SHA1 Message Date
Will Jones
c7f189f27b chore: upgrade lance to stable 4.0.0 (#3207)
Bumps all lance-* workspace dependencies from `4.0.0-rc.3` (git source)
to the stable `4.0.0` release on crates.io, removing the `git`/`tag`
overrides.

No code changes were required — compiles and passes clippy cleanly.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-30 17:05:45 -07:00
yaommen
a0a2942ad5 fix: respect max_batch_length for Rust vector and hybrid queries (#3172)
Fixes #1540

I could not reproduce this on current `main` from Python, but I could
still reproduce it from the Rust SDK.

Python no longer reproduces because the current Python vector/hybrid
query paths re-chunk results into a `pyarrow.Table` before returning
batches. Rust still reproduced because `max_batch_length` was passed
into planning/scanning, but vector search could still emit larger
`RecordBatch`es later in execution (for example after KNN / TopK), so it
was not enforced on the final Rust output stream.

This PR enforces `max_batch_length` on the final Rust query output
stream and adds Rust regression coverage.

Before the fix, the Rust repro produced:
`num_batches=2, max_batch=8192, min_batch=1808, all_le_100=false`

After the fix, the same repro produces batches `<= 100`.

## Runnable Rust repro

Before this fix, current `main` could still return batches like `[8192,
1808]` here even with `max_batch_length = 100`:

```rust
use std::sync::Arc;

use arrow_array::{
    types::Float32Type, FixedSizeListArray, RecordBatch, RecordBatchReader, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase, QueryExecutionOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tmp = tempfile::tempdir()?;
    let uri = tmp.path().to_str().unwrap();

    let rows = 10_000;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
            false,
        ),
    ]));

    let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
    let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
        (0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
        4,
    );
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
    let reader: Box<dyn RecordBatchReader + Send> = Box::new(
        arrow_array::RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema),
    );

    let db = lancedb::connect(uri).execute().await?;
    let table = db.create_table("test", reader).execute().await?;

    let mut opts = QueryExecutionOptions::default();
    opts.max_batch_length = 100;

    let mut stream = table
        .query()
        .nearest_to(vec![0.0, 1.0, 2.0, 3.0])?
        .limit(rows)
        .execute_with_options(opts)
        .await?;

    let mut sizes = Vec::new();
    while let Some(batch) = stream.try_next().await? {
        sizes.push(batch.num_rows());
    }

    println!("{sizes:?}");
    Ok(())
}
```

Signed-off-by: yaommen <myanstu@163.com>
2026-03-30 15:43:58 -07:00
Will Jones
e3d53dd185 fix(python): skip test_url_retrieve_downloads_image when PIL not installed (#3208)
The test added in #3190 unconditionally imports `PIL`, which is an
optional dependency. This causes CI failures in environments where
Pillow isn't installed (`ModuleNotFoundError: No module named 'PIL'`).

Use `pytest.importorskip` to skip gracefully when Pillow is unavailable.

Fixes CI failure on main.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-30 14:48:49 -07:00
9 changed files with 365 additions and 86 deletions

118
Cargo.lock generated
View File

@@ -108,7 +108,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -119,7 +119,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -2682,7 +2682,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -2876,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3072,8 +3072,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2195cc7f87e84bd695586137de99605e7e9579b26ec5e01b82960ddb4d0922f2"
dependencies = [
"arrow-array",
"rand 0.9.2",
@@ -3736,7 +3737,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2 0.5.10",
"socket2 0.6.3",
"system-configuration",
"tokio",
"tower-service",
@@ -4037,7 +4038,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde_core",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -4123,8 +4124,9 @@ dependencies = [
[[package]]
name = "lance"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efe6c3ddd79cdfd2b7e1c23cafae52806906bc40fbd97de9e8cf2f8c7a75fc04"
dependencies = [
"arrow",
"arrow-arith",
@@ -4190,8 +4192,9 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d9f5d95bdda2a2b790f1fb8028b5b6dcf661abeb3133a8bca0f3d24b054af87"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4211,8 +4214,9 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f827d6ab9f8f337a9509d5ad66a12f3314db8713868260521c344ef6135eb4e4"
dependencies = [
"arrayref",
"paste",
@@ -4221,8 +4225,9 @@ dependencies = [
[[package]]
name = "lance-core"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f1e25df6a79bf72ee6bcde0851f19b1cd36c5848c1b7db83340882d3c9fdecb"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4259,8 +4264,9 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93146de8ae720cb90edef81c2f2d0a1b065fc2f23ecff2419546f389b0fa70a4"
dependencies = [
"arrow",
"arrow-array",
@@ -4290,8 +4296,9 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccec8ce4d8e0a87a99c431dab2364398029f2ffb649c1a693c60c79e05ed30dd"
dependencies = [
"arrow",
"arrow-array",
@@ -4309,8 +4316,9 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c1aec0bbbac6bce829bc10f1ba066258126100596c375fb71908ecf11c2c2a5"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4347,8 +4355,9 @@ dependencies = [
[[package]]
name = "lance-file"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14a8c548804f5b17486dc2d3282356ed1957095a852780283bc401fdd69e9075"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4380,8 +4389,9 @@ dependencies = [
[[package]]
name = "lance-index"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da212f0090ea59f79ac3686660f596520c167fe1cb5f408900cf71d215f0e03"
dependencies = [
"arrow",
"arrow-arith",
@@ -4445,8 +4455,9 @@ dependencies = [
[[package]]
name = "lance-io"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d958eb4b56f03bbe0f5f85eb2b4e9657882812297b6f711f201ffc995f259f"
dependencies = [
"arrow",
"arrow-arith",
@@ -4487,8 +4498,9 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0285b70da35def7ed95e150fae1d5308089554e1290470403ed3c50cb235bc5e"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4504,8 +4516,9 @@ dependencies = [
[[package]]
name = "lance-namespace"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f78e2a828b654e062a495462c6e3eb4fcf0e7e907d761b8f217fc09ccd3ceac"
dependencies = [
"arrow",
"async-trait",
@@ -4518,8 +4531,9 @@ dependencies = [
[[package]]
name = "lance-namespace-impls"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2392314f3da38f00d166295e44244208a65ccfc256e274fa8631849fc3f4d94"
dependencies = [
"arrow",
"arrow-ipc",
@@ -4533,7 +4547,6 @@ dependencies = [
"lance-core",
"lance-index",
"lance-io",
"lance-linalg",
"lance-namespace",
"lance-table",
"log",
@@ -4564,8 +4577,9 @@ dependencies = [
[[package]]
name = "lance-table"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3df9c4adca3eb2074b3850432a9fb34248a3d90c3d6427d158b13ff9355664ee"
dependencies = [
"arrow",
"arrow-array",
@@ -4604,8 +4618,9 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "5.0.0-beta.2"
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.2#34e311c7632f62d8e4ff3a6e8bd124f84f0b70dc"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ed7119bdd6983718387b4ac44af873a165262ca94f181b104cd6f97912eb3bf"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -5308,7 +5323,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -6287,7 +6302,7 @@ version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"itertools 0.14.0",
"log",
"multimap",
@@ -6474,7 +6489,7 @@ dependencies = [
"quinn-udp",
"rustc-hash",
"rustls 0.23.37",
"socket2 0.5.10",
"socket2 0.6.3",
"thiserror 2.0.18",
"tokio",
"tracing",
@@ -6511,9 +6526,9 @@ dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.5.10",
"socket2 0.6.3",
"tracing",
"windows-sys 0.52.0",
"windows-sys 0.60.2",
]
[[package]]
@@ -7042,7 +7057,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -7055,7 +7070,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.12.1",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -7575,7 +7590,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -7587,7 +7602,7 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -7610,7 +7625,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [
"libc",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -7714,7 +7729,6 @@ dependencies = [
"cfg-if",
"libc",
"psm",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
@@ -8075,7 +8089,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix 1.1.4",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -8880,7 +8894,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]

View File

@@ -15,20 +15,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=5.0.0-beta.2", default-features = false, "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=5.0.0-beta.2", default-features = false, "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=5.0.0-beta.2", default-features = false, "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=5.0.0-beta.2", "tag" = "v5.0.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
lance = { version = "=4.0.0", default-features = false }
lance-core = { version = "=4.0.0" }
lance-datagen = { version = "=4.0.0" }
lance-file = { version = "=4.0.0" }
lance-io = { version = "=4.0.0", default-features = false }
lance-index = { version = "=4.0.0" }
lance-linalg = { version = "=4.0.0" }
lance-namespace = { version = "=4.0.0" }
lance-namespace-impls = { version = "=4.0.0", default-features = false }
lance-table = { version = "=4.0.0" }
lance-testing = { version = "=4.0.0" }
lance-datafusion = { version = "=4.0.0" }
lance-encoding = { version = "=4.0.0" }
lance-arrow = { version = "=4.0.0" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false }

View File

@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>5.0.0-beta.2</lance-core.version>
<lance-core.version>3.0.1</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -559,7 +559,8 @@ def test_url_retrieve_downloads_image():
matching the real usage pattern in embedding functions.
"""
import io
from PIL import Image
Image = pytest.importorskip("PIL.Image")
from lancedb.embeddings.utils import url_retrieve
image_url = "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use std::{future::Future, time::Duration};
use arrow::compute::concat_batches;
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
use arrow_schema::{DataType, SchemaRef};
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
@@ -17,15 +17,17 @@ use lance_datafusion::exec::execute_plan;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::DistanceType;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
use crate::table::BaseTable;
use crate::utils::TimeoutStream;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
table::AnyQuery,
};
mod hybrid;
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
}
}
impl QueryExecutionOptions {
fn without_output_batch_length_limit(&self) -> Self {
let mut options = self.clone();
options.max_batch_length = 0;
options
}
}
/// A trait for a query object that can be executed to get results
///
/// There are various kinds of queries but they all return results
@@ -1180,6 +1190,8 @@ impl VectorQuery {
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let max_batch_length = options.max_batch_length as usize;
let internal_options = options.without_output_batch_length_limit();
// clone query and specify we want to include row IDs, which can be needed for reranking
let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone();
@@ -1189,8 +1201,8 @@ impl VectorQuery {
vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(
fts_query.execute_with_options(options.clone()),
vector_query.inner_execute_with_options(options)
fts_query.execute_with_options(internal_options.clone()),
vector_query.inner_execute_with_options(internal_options)
)?;
let (fts_results, vec_results) = try_join!(
@@ -1245,9 +1257,7 @@ impl VectorQuery {
results = results.drop_column(ROW_ID)?;
}
Ok(SendableRecordBatchStream::from(
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
Ok(single_batch_stream(results, max_batch_length))
}
async fn inner_execute_with_options(
@@ -1256,6 +1266,7 @@ impl VectorQuery {
) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan(options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
@@ -1265,6 +1276,25 @@ impl VectorQuery {
}
}
fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
let schema = batch.schema();
if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
return Box::pin(SimpleRecordBatchStream::new(
stream::iter([Ok(batch)]),
schema,
));
}
let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
let mut offset = 0;
while offset < batch.num_rows() {
let length = (batch.num_rows() - offset).min(max_batch_length);
batches.push(Ok(batch.slice(offset, length)));
offset += length;
}
Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
}
impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
let query = AnyQuery::VectorQuery(self.request.clone());
@@ -1753,6 +1783,50 @@ mod tests {
.unwrap()
}
async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
let dataset_path = tmp_dir.path().join("large_test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
4,
),
false,
),
]));
let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
4,
);
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
let conn = connect(uri).execute().await.unwrap();
conn.create_table("my_table", vec![batch])
.execute()
.await
.unwrap()
}
async fn assert_stream_batches_at_most(
mut results: SendableRecordBatchStream,
max_batch_length: usize,
) {
let mut saw_batch = false;
while let Some(batch) = results.next().await {
let batch = batch.unwrap();
saw_batch = true;
assert!(batch.num_rows() <= max_batch_length);
}
assert!(saw_batch);
}
#[tokio::test]
async fn test_execute_with_options() {
let tmp_dir = tempdir().unwrap();
@@ -1772,6 +1846,83 @@ mod tests {
}
}
#[tokio::test]
async fn test_vector_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let table = make_large_vector_table(&tmp_dir, 10_000).await;
let results = table
.query()
.nearest_to(vec![0.0, 1.0, 2.0, 3.0])
.unwrap()
.limit(10_000)
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test]
async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path();
let conn = connect(dataset_path.to_str().unwrap())
.execute()
.await
.unwrap();
let dims = 2;
let rows = 512;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("text", DataType::Utf8, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dims,
),
false,
),
]));
let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
dims,
);
let record_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
let table = conn
.create_table("my_table", record_batch)
.execute()
.await
.unwrap();
table
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
.replace(true)
.execute()
.await
.unwrap();
let results = table
.query()
.full_text_search(FullTextSearchQuery::new("match".to_string()))
.limit(rows)
.nearest_to(&[0.0, 0.0])
.unwrap()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 100,
..Default::default()
})
.await
.unwrap();
assert_stream_batches_at_most(results, 100).await;
}
#[tokio::test]
async fn test_analyze_plan() {
let tmp_dir = tempdir().unwrap();

View File

@@ -19,11 +19,11 @@ pub use lance::dataset::Version;
use lance::dataset::WriteMode;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::{InsertBuilder, WriteParams};
use lance::index::DatasetIndexExt;
use lance::index::vector::VectorIndexParams;
use lance::index::vector::utils::infer_vector_dim;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance_datafusion::utils::StreamingWriteSource;
use lance_index::DatasetIndexExt;
use lance_index::IndexType;
use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams};
use lance_index::vector::bq::RQBuildParams;

View File

@@ -10,7 +10,7 @@ use std::sync::Arc;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{CompactionMetrics, IndexRemapperOptions, compact_files};
use lance::index::DatasetIndexExt;
use lance_index::DatasetIndexExt;
use lance_index::optimize::OptimizeOptions;
use log::info;

View File

@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
use crate::query::{
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
};
use crate::utils::{TimeoutStream, default_vector_column};
use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::Array;
@@ -66,6 +66,7 @@ async fn execute_generic_query(
) -> Result<DatasetRecordBatchStream> {
let plan = create_plan(table, query, options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
@@ -200,7 +201,9 @@ pub async fn create_plan(
scanner.with_row_id();
}
scanner.batch_size(options.max_batch_length as usize);
if options.max_batch_length > 0 {
scanner.batch_size(options.max_batch_length as usize);
}
if query.base.fast_search {
scanner.fast_search();

View File

@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
}
}
/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
pub struct MaxBatchLengthStream {
inner: SendableRecordBatchStream,
max_batch_length: Option<usize>,
buffered_batch: Option<RecordBatch>,
buffered_offset: usize,
}
impl MaxBatchLengthStream {
pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
Self {
inner,
max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
buffered_batch: None,
buffered_offset: 0,
}
}
pub fn new_boxed(
inner: SendableRecordBatchStream,
max_batch_length: usize,
) -> SendableRecordBatchStream {
if max_batch_length == 0 {
inner
} else {
Box::pin(Self::new(inner, max_batch_length))
}
}
}
impl RecordBatchStream for MaxBatchLengthStream {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}
impl Stream for MaxBatchLengthStream {
type Item = DataFusionResult<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
loop {
let Some(max_batch_length) = self.max_batch_length else {
return Pin::new(&mut self.inner).poll_next(cx);
};
if let Some(batch) = self.buffered_batch.clone() {
if self.buffered_offset < batch.num_rows() {
let remaining = batch.num_rows() - self.buffered_offset;
let length = remaining.min(max_batch_length);
let sliced = batch.slice(self.buffered_offset, length);
self.buffered_offset += length;
if self.buffered_offset >= batch.num_rows() {
self.buffered_batch = None;
self.buffered_offset = 0;
}
return std::task::Poll::Ready(Some(Ok(sliced)));
}
self.buffered_batch = None;
self.buffered_offset = 0;
}
match Pin::new(&mut self.inner).poll_next(cx) {
std::task::Poll::Ready(Some(Ok(batch))) => {
if batch.num_rows() <= max_batch_length {
return std::task::Poll::Ready(Some(Ok(batch)));
}
self.buffered_batch = Some(batch);
self.buffered_offset = 0;
}
other => return other,
}
}
}
}
#[cfg(test)]
mod tests {
use arrow_array::Int32Array;
@@ -470,7 +549,7 @@ mod tests {
assert_eq!(string_to_datatype(string), Some(expected));
}
fn sample_batch() -> RecordBatch {
fn sample_batch(num_rows: i32) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"col1",
DataType::Int32,
@@ -478,14 +557,14 @@ mod tests {
)]));
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
)
.unwrap()
}
#[tokio::test]
async fn test_timeout_stream() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -515,7 +594,7 @@ mod tests {
#[tokio::test]
async fn test_timeout_stream_zero_duration() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -534,7 +613,7 @@ mod tests {
#[tokio::test]
async fn test_timeout_stream_completes_normally() {
let batch = sample_batch();
let batch = sample_batch(3);
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
@@ -552,4 +631,35 @@ mod tests {
// Stream should be empty now
assert!(timeout_stream.next().await.is_none());
}
async fn collect_batch_sizes(
stream: SendableRecordBatchStream,
max_batch_length: usize,
) -> Vec<usize> {
let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
sliced_stream
.by_ref()
.map(|batch| batch.unwrap().num_rows())
.collect::<Vec<_>>()
.await
}
#[tokio::test]
async fn test_max_batch_length_stream_behaviors() {
let schema = sample_batch(7).schema();
let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
assert_eq!(
collect_batch_sizes(sendable_stream, 3).await,
vec![2, 3, 3, 1]
);
let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
schema,
stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
));
assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
}
}