Mirror of https://github.com/lancedb/lancedb.git (synced 2026-03-28 11:30:39 +00:00)

Compare commits (2 commits): python-v0.… → codex/upda…

| Author | SHA1 | Date |
|---|---|---|
| | b2653e5524 | |
| | 7324bcec84 | |
Cargo.lock (66 changed lines; generated file)
@@ -3078,8 +3078,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
 
 [[package]]
 name = "fsst"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow-array",
  "rand 0.9.2",
@@ -4226,8 +4226,8 @@ dependencies = [
 
 [[package]]
 name = "lance"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4293,8 +4293,8 @@ dependencies = [
 
 [[package]]
 name = "lance-arrow"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -4313,8 +4313,8 @@ dependencies = [
 
 [[package]]
 name = "lance-bitpacking"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrayref",
  "paste",
@@ -4323,8 +4323,8 @@ dependencies = [
 
 [[package]]
 name = "lance-core"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -4361,8 +4361,8 @@ dependencies = [
 
 [[package]]
 name = "lance-datafusion"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4392,8 +4392,8 @@ dependencies = [
 
 [[package]]
 name = "lance-datagen"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4411,8 +4411,8 @@ dependencies = [
 
 [[package]]
 name = "lance-encoding"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow-arith",
  "arrow-array",
@@ -4449,8 +4449,8 @@ dependencies = [
 
 [[package]]
 name = "lance-file"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow-arith",
  "arrow-array",
@@ -4482,8 +4482,8 @@ dependencies = [
 
 [[package]]
 name = "lance-index"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4546,8 +4546,8 @@ dependencies = [
 
 [[package]]
 name = "lance-io"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "arrow-arith",
@@ -4587,8 +4587,8 @@ dependencies = [
 
 [[package]]
 name = "lance-linalg"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -4604,8 +4604,8 @@ dependencies = [
 
 [[package]]
 name = "lance-namespace"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "async-trait",
@@ -4617,8 +4617,8 @@ dependencies = [
 
 [[package]]
 name = "lance-namespace-impls"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "arrow-ipc",
@@ -4662,8 +4662,8 @@ dependencies = [
 
 [[package]]
 name = "lance-table"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow",
  "arrow-array",
@@ -4702,8 +4702,8 @@ dependencies = [
 
 [[package]]
 name = "lance-testing"
-version = "3.1.0-beta.2"
-source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.2#ae3b1f413cc49d783f51abe62c8261c106c9b6cd"
+version = "3.1.0-beta.1"
+source = "git+https://github.com/lance-format/lance.git?tag=v3.1.0-beta.1#c36a4d9071b92c81c1d5e699ffeef8598ac48d78"
 dependencies = [
  "arrow-array",
  "arrow-schema",
@@ -4744,10 +4744,8 @@ dependencies = [
  "datafusion-common",
  "datafusion-execution",
  "datafusion-expr",
- "datafusion-functions",
  "datafusion-physical-expr",
  "datafusion-physical-plan",
- "datafusion-sql",
  "futures",
  "half",
  "hf-hub",
Cargo.toml (28 changed lines)
@@ -15,20 +15,20 @@ categories = ["database-implementations"]
 rust-version = "1.91.0"
 
 [workspace.dependencies]
-lance = { "version" = "=3.1.0-beta.2", default-features = false, "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-core = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-datagen = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-file = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-io = { "version" = "=3.1.0-beta.2", default-features = false, "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-index = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-linalg = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-namespace = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-namespace-impls = { "version" = "=3.1.0-beta.2", default-features = false, "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-table = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-testing = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-datafusion = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-encoding = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
-lance-arrow = { "version" = "=3.1.0-beta.2", "tag" = "v3.1.0-beta.2", "git" = "https://github.com/lance-format/lance.git" }
+lance = { "version" = "=3.1.0-beta.1", default-features = false, "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-core = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-datagen = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-file = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-io = { "version" = "=3.1.0-beta.1", default-features = false, "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-index = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-linalg = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-namespace-impls = { "version" = "=3.1.0-beta.1", default-features = false, "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-table = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-testing = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-datafusion = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-encoding = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
+lance-arrow = { "version" = "=3.1.0-beta.1", "tag" = "v3.1.0-beta.1", "git" = "https://github.com/lance-format/lance.git" }
 ahash = "0.8"
 # Note that this one does not include pyarrow
 arrow = { version = "57.2", optional = false }
@@ -28,7 +28,7 @@
 <properties>
   <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
   <arrow.version>15.0.0</arrow.version>
-  <lance-core.version>3.1.0-beta.2</lance-core.version>
+  <lance-core.version>3.1.0-beta.1</lance-core.version>
   <spotless.skip>false</spotless.skip>
   <spotless.version>2.30.0</spotless.version>
   <spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.30.0-beta.2"
+current_version = "0.30.0-beta.1"
 parse = """(?x)
 (?P<major>0|[1-9]\\d*)\\.
 (?P<minor>0|[1-9]\\d*)\\.
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.30.0-beta.2"
+version = "0.30.0-beta.1"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
@@ -45,7 +45,7 @@ repository = "https://github.com/lancedb/lancedb"
 
 [project.optional-dependencies]
 pylance = [
-    "pylance>=1.0.0b14",
+    "pylance>=3.1.0b1",
 ]
 tests = [
     "aiohttp",
@@ -59,9 +59,9 @@ tests = [
     "polars>=0.19, <=1.3.0",
     "tantivy",
     "pyarrow-stubs",
-    "pylance>=1.0.0b14",
+    "pylance>=3.1.0b1",
     "requests",
-    "datafusion<52",
+    "datafusion>=51,<52", # Must match pylance's DataFusion version
 ]
 dev = [
     "ruff",
@@ -25,9 +25,7 @@ datafusion-catalog.workspace = true
 datafusion-common.workspace = true
 datafusion-execution.workspace = true
 datafusion-expr.workspace = true
-datafusion-functions = "51.0"
 datafusion-physical-expr.workspace = true
-datafusion-sql = "51.0"
 datafusion-physical-plan.workspace = true
 datafusion.workspace = true
 object_store = { workspace = true }
@@ -9,6 +9,13 @@
 
 use std::sync::Arc;
 
+use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator, RecordBatchReader};
+use arrow_schema::{ArrowError, SchemaRef};
+use async_trait::async_trait;
+use futures::stream::once;
+use futures::StreamExt;
+use lance_datafusion::utils::StreamingWriteSource;
+
 use crate::arrow::{
     SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream,
 };
@@ -18,12 +25,6 @@ use crate::embeddings::{
 };
 use crate::table::{ColumnDefinition, ColumnKind, TableDefinition};
 use crate::{Error, Result};
-use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator, RecordBatchReader};
-use arrow_schema::{ArrowError, SchemaRef};
-use async_trait::async_trait;
-use futures::stream::once;
-use futures::StreamExt;
-use lance_datafusion::utils::StreamingWriteSource;
 
 pub trait Scannable: Send {
     /// Returns the schema of the data.
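
For readers following the hunks in this file: the `Scannable` trait's surface can be inferred from the context lines and impl blocks below. This is a simplified, self-contained sketch of that shape, not the crate's real definition — the actual `SendableRecordBatchStream` carries its schema, and the error type is the crate's `Error`, not `String`:

```rust
use std::pin::Pin;
use std::sync::Arc;

use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use futures::stream::{self, Stream, StreamExt};

// Simplified stand-in for the crate's SendableRecordBatchStream.
type BatchStream = Pin<Box<dyn Stream<Item = Result<RecordBatch, String>> + Send>>;

// Method set inferred from the PeekedScannable impl in the diff below.
trait Scannable: Send {
    fn schema(&self) -> SchemaRef;
    fn num_rows(&self) -> Option<usize>;
    fn rescannable(&self) -> bool;
    fn scan_as_stream(&mut self) -> BatchStream;
}

// In-memory batches are trivially scannable and re-scannable.
struct VecSource {
    batches: Vec<RecordBatch>,
}

impl Scannable for VecSource {
    fn schema(&self) -> SchemaRef {
        self.batches[0].schema()
    }
    fn num_rows(&self) -> Option<usize> {
        Some(self.batches.iter().map(|b| b.num_rows()).sum())
    }
    fn rescannable(&self) -> bool {
        true // cloning the Vec lets us scan again
    }
    fn scan_as_stream(&mut self) -> BatchStream {
        stream::iter(self.batches.clone().into_iter().map(Ok)).boxed()
    }
}

#[tokio::main]
async fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch =
        RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap();
    let mut src = VecSource { batches: vec![batch] };
    assert_eq!(src.num_rows(), Some(3));
    let n = src.scan_as_stream().count().await;
    println!("{n} batch(es) scanned");
}
```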
@@ -348,133 +349,6 @@ pub fn scannable_with_embeddings(
     Ok(inner)
 }
 
-/// A wrapper that buffers the first RecordBatch from a Scannable so we can
-/// inspect it (e.g. to estimate data size) without losing it.
-pub(crate) struct PeekedScannable {
-    inner: Box<dyn Scannable>,
-    peeked: Option<RecordBatch>,
-    /// The first item from the stream, if it was an error. Stored so we can
-    /// re-emit it from `scan_as_stream` instead of silently dropping it.
-    first_error: Option<crate::Error>,
-    stream: Option<SendableRecordBatchStream>,
-}
-
-impl PeekedScannable {
-    pub fn new(inner: Box<dyn Scannable>) -> Self {
-        Self {
-            inner,
-            peeked: None,
-            first_error: None,
-            stream: None,
-        }
-    }
-
-    /// Reads and buffers the first batch from the inner scannable.
-    /// Returns a clone of it. Subsequent calls return the same batch.
-    ///
-    /// Returns `None` if the stream is empty or the first item is an error.
-    /// Errors are preserved and re-emitted by `scan_as_stream`.
-    pub async fn peek(&mut self) -> Option<RecordBatch> {
-        if self.peeked.is_some() {
-            return self.peeked.clone();
-        }
-        // Already peeked and got an error or empty stream.
-        if self.stream.is_some() || self.first_error.is_some() {
-            return None;
-        }
-        let mut stream = self.inner.scan_as_stream();
-        match stream.next().await {
-            Some(Ok(batch)) => {
-                self.peeked = Some(batch.clone());
-                self.stream = Some(stream);
-                Some(batch)
-            }
-            Some(Err(e)) => {
-                self.first_error = Some(e);
-                self.stream = Some(stream);
-                None
-            }
-            None => {
-                self.stream = Some(stream);
-                None
-            }
-        }
-    }
-}
-
-impl Scannable for PeekedScannable {
-    fn schema(&self) -> SchemaRef {
-        self.inner.schema()
-    }
-
-    fn num_rows(&self) -> Option<usize> {
-        self.inner.num_rows()
-    }
-
-    fn rescannable(&self) -> bool {
-        self.inner.rescannable()
-    }
-
-    fn scan_as_stream(&mut self) -> SendableRecordBatchStream {
-        let schema = self.inner.schema();
-
-        // If peek() hit an error, prepend it so downstream sees the error.
-        let error_item = self.first_error.take().map(Err);
-
-        match (self.peeked.take(), self.stream.take()) {
-            (Some(batch), Some(rest)) => {
-                let prepend = futures::stream::once(std::future::ready(Ok(batch)));
-                Box::pin(SimpleRecordBatchStream {
-                    schema,
-                    stream: prepend.chain(rest),
-                })
-            }
-            (Some(batch), None) => Box::pin(SimpleRecordBatchStream {
-                schema,
-                stream: futures::stream::once(std::future::ready(Ok(batch))),
-            }),
-            (None, Some(rest)) => {
-                if let Some(err) = error_item {
-                    let stream = futures::stream::once(std::future::ready(err));
-                    Box::pin(SimpleRecordBatchStream { schema, stream })
-                } else {
-                    rest
-                }
-            }
-            (None, None) => {
-                // peek() was never called — just delegate
-                self.inner.scan_as_stream()
-            }
-        }
-    }
-}
-
-/// Compute the number of write partitions based on data size estimates.
-///
-/// `sample_bytes` and `sample_rows` come from a representative batch and are
-/// used to estimate per-row size. `total_rows_hint` is the total row count
-/// when known; otherwise `sample_rows` row count is used as a lower bound
-/// estimate.
-///
-/// Targets roughly 1 million rows or 2 GB per partition, capped at
-/// `max_partitions` (typically the number of available CPU cores).
-pub(crate) fn estimate_write_partitions(
-    sample_bytes: usize,
-    sample_rows: usize,
-    total_rows_hint: Option<usize>,
-    max_partitions: usize,
-) -> usize {
-    if sample_rows == 0 {
-        return 1;
-    }
-    let bytes_per_row = sample_bytes / sample_rows;
-    let total_rows = total_rows_hint.unwrap_or(sample_rows);
-    let total_bytes = total_rows * bytes_per_row;
-    let by_rows = total_rows.div_ceil(1_000_000);
-    let by_bytes = total_bytes.div_ceil(2 * 1024 * 1024 * 1024);
-    by_rows.max(by_bytes).max(1).min(max_partitions)
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
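
The `PeekedScannable` removed in the hunk above is a general buffer-one-item pattern over a fallible stream: look at the first batch without losing it or the first error. A minimal self-contained sketch of the same idea using plain `futures` streams rather than the crate's `Scannable` trait (names and the `String` error type here are illustrative, not from the codebase):

```rust
use futures::stream::{self, BoxStream, StreamExt};

/// Buffer the first item of a fallible stream, returning a clone of it
/// (if it was Ok) plus a stream that still yields *all* items, errors
/// included -- the same contract PeekedScannable::peek() provides.
async fn peek_first<T: Clone + Send + 'static>(
    mut input: BoxStream<'static, Result<T, String>>,
) -> (Option<T>, BoxStream<'static, Result<T, String>>) {
    match input.next().await {
        // First item is a value: hand back a clone and re-prepend the
        // original so downstream consumers still see the full stream.
        Some(Ok(item)) => {
            let peeked = item.clone();
            let restored = stream::once(async move { Ok(item) }).chain(input);
            (Some(peeked), restored.boxed())
        }
        // First item is an error: report "nothing to peek" but keep the
        // error in the stream instead of silently dropping it.
        Some(Err(e)) => {
            let restored = stream::once(async move { Err(e) }).chain(input);
            (None, restored.boxed())
        }
        // Empty stream.
        None => (None, stream::empty().boxed()),
    }
}

#[tokio::main]
async fn main() {
    let input = stream::iter(vec![Ok(1), Ok(2), Err("boom".to_string())]).boxed();
    let (first, rest) = peek_first(input).await;
    assert_eq!(first, Some(1));
    // Nothing was lost: the restored stream still yields 1, 2, then the error.
    let collected: Vec<_> = rest.collect().await;
    assert_eq!(collected.len(), 3);
    println!("peeked {first:?}, stream still has {} items", collected.len());
}
```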
@@ -571,231 +445,6 @@ mod tests {
         assert!(result2.unwrap().is_err());
     }
 
-    mod peeked_scannable_tests {
-        use crate::test_utils::TestCustomError;
-
-        use super::*;
-
-        #[tokio::test]
-        async fn test_peek_returns_first_batch() {
-            let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
-            let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
-
-            let first = peeked.peek().await.unwrap();
-            assert_eq!(first, batch);
-        }
-
-        #[tokio::test]
-        async fn test_peek_is_idempotent() {
-            let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
-            let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
-
-            let first = peeked.peek().await.unwrap();
-            let second = peeked.peek().await.unwrap();
-            assert_eq!(first, second);
-        }
-
-        #[tokio::test]
-        async fn test_scan_after_peek_returns_all_data() {
-            let batches = vec![
-                record_batch!(("id", Int64, [1, 2])).unwrap(),
-                record_batch!(("id", Int64, [3, 4, 5])).unwrap(),
-            ];
-            let mut peeked = PeekedScannable::new(Box::new(batches.clone()));
-
-            let first = peeked.peek().await.unwrap();
-            assert_eq!(first, batches[0]);
-
-            let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
-            assert_eq!(result.len(), 2);
-            assert_eq!(result[0], batches[0]);
-            assert_eq!(result[1], batches[1]);
-        }
-
-        #[tokio::test]
-        async fn test_scan_without_peek_passes_through() {
-            let batch = record_batch!(("id", Int64, [1, 2, 3])).unwrap();
-            let mut peeked = PeekedScannable::new(Box::new(batch.clone()));
-
-            let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
-            assert_eq!(result.len(), 1);
-            assert_eq!(result[0], batch);
-        }
-
-        #[tokio::test]
-        async fn test_delegates_num_rows() {
-            let batches = vec![
-                record_batch!(("id", Int64, [1, 2])).unwrap(),
-                record_batch!(("id", Int64, [3])).unwrap(),
-            ];
-            let peeked = PeekedScannable::new(Box::new(batches));
-            assert_eq!(peeked.num_rows(), Some(3));
-        }
-
-        #[tokio::test]
-        async fn test_non_rescannable_stream_data_preserved() {
-            let batches = vec![
-                record_batch!(("id", Int64, [1, 2])).unwrap(),
-                record_batch!(("id", Int64, [3])).unwrap(),
-            ];
-            let schema = batches[0].schema();
-            let inner = futures::stream::iter(batches.clone().into_iter().map(Ok));
-            let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
-                schema,
-                stream: inner,
-            });
-
-            let mut peeked = PeekedScannable::new(Box::new(stream));
-            assert!(!peeked.rescannable());
-            assert_eq!(peeked.num_rows(), None);
-
-            let first = peeked.peek().await.unwrap();
-            assert_eq!(first, batches[0]);
-
-            // All data is still available via scan_as_stream
-            let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
-            assert_eq!(result.len(), 2);
-            assert_eq!(result[0], batches[0]);
-            assert_eq!(result[1], batches[1]);
-        }
-
-        #[tokio::test]
-        async fn test_error_in_first_batch_propagates() {
-            let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
-                "id",
-                arrow_schema::DataType::Int64,
-                false,
-            )]));
-            let inner = futures::stream::iter(vec![Err(Error::External {
-                source: Box::new(TestCustomError),
-            })]);
-            let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
-                schema,
-                stream: inner,
-            });
-
-            let mut peeked = PeekedScannable::new(Box::new(stream));
-
-            // peek returns None for errors
-            assert!(peeked.peek().await.is_none());
-
-            // But the error should come through when scanning
-            let mut stream = peeked.scan_as_stream();
-            let first = stream.next().await.unwrap();
-            assert!(first.is_err());
-            let err = first.unwrap_err();
-            assert!(
-                matches!(&err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
-                "Expected TestCustomError to be preserved, got: {err}"
-            );
-        }
-
-        #[tokio::test]
-        async fn test_error_in_later_batch_propagates() {
-            let good_batch = record_batch!(("id", Int64, [1, 2])).unwrap();
-            let schema = good_batch.schema();
-            let inner = futures::stream::iter(vec![
-                Ok(good_batch.clone()),
-                Err(Error::External {
-                    source: Box::new(TestCustomError),
-                }),
-            ]);
-            let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
-                schema,
-                stream: inner,
-            });
-
-            let mut peeked = PeekedScannable::new(Box::new(stream));
-
-            // peek succeeds with the first batch
-            let first = peeked.peek().await.unwrap();
-            assert_eq!(first, good_batch);
-
-            // scan_as_stream should yield the first batch, then the error
-            let mut stream = peeked.scan_as_stream();
-            let batch1 = stream.next().await.unwrap().unwrap();
-            assert_eq!(batch1, good_batch);
-
-            let batch2 = stream.next().await.unwrap();
-            assert!(batch2.is_err());
-            let err = batch2.unwrap_err();
-            assert!(
-                matches!(&err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
-                "Expected TestCustomError to be preserved, got: {err}"
-            );
-        }
-
-        #[tokio::test]
-        async fn test_empty_stream_returns_none() {
-            let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
-                "id",
-                arrow_schema::DataType::Int64,
-                false,
-            )]));
-            let inner = futures::stream::empty();
-            let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
-                schema,
-                stream: inner,
-            });
-
-            let mut peeked = PeekedScannable::new(Box::new(stream));
-            assert!(peeked.peek().await.is_none());
-
-            // Scanning an empty (post-peek) stream should yield nothing
-            let result: Vec<RecordBatch> = peeked.scan_as_stream().try_collect().await.unwrap();
-            assert!(result.is_empty());
-        }
-    }
-
-    mod estimate_write_partitions_tests {
-        use super::*;
-
-        #[test]
-        fn test_small_data_single_partition() {
-            // 100 rows * 24 bytes/row = 2400 bytes — well under both thresholds
-            assert_eq!(estimate_write_partitions(2400, 100, Some(100), 8), 1);
-        }
-
-        #[test]
-        fn test_scales_by_row_count() {
-            // 2.5M rows at 24 bytes/row — row threshold dominates
-            // ceil(2_500_000 / 1_000_000) = 3
-            assert_eq!(estimate_write_partitions(72, 3, Some(2_500_000), 8), 3);
-        }
-
-        #[test]
-        fn test_scales_by_byte_size() {
-            // 100k rows at 40KB/row = ~4GB total → ceil(4GB / 2GB) = 2
-            let sample_bytes = 40_000 * 10;
-            assert_eq!(
-                estimate_write_partitions(sample_bytes, 10, Some(100_000), 8),
-                2
-            );
-        }
-
-        #[test]
-        fn test_capped_at_max_partitions() {
-            // 10M rows would want 10 partitions, but capped at 4
-            assert_eq!(estimate_write_partitions(72, 3, Some(10_000_000), 4), 4);
-        }
-
-        #[test]
-        fn test_zero_sample_rows_returns_one() {
-            assert_eq!(estimate_write_partitions(0, 0, Some(1_000_000), 8), 1);
-        }
-
-        #[test]
-        fn test_no_row_hint_uses_sample_size() {
-            // Without a hint, uses sample_rows (3), which is small
-            assert_eq!(estimate_write_partitions(72, 3, None, 8), 1);
-        }
-
-        #[test]
-        fn test_always_at_least_one() {
-            assert_eq!(estimate_write_partitions(24, 1, Some(1), 8), 1);
-        }
-    }
-
     mod embedding_tests {
         use super::*;
         use crate::embeddings::MemoryRegistry;
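
The partition heuristic exercised by the removed tests targets roughly one million rows or 2 GiB per partition, whichever asks for more, capped at the CPU count. A worked example (this is a standalone copy of the removed function so the arithmetic can be run; the numbers in `main` are illustrative, not from the test suite):

```rust
/// Standalone copy of the removed estimate_write_partitions helper:
/// ~1M rows or ~2 GiB per partition, capped at max_partitions, never < 1.
fn estimate_write_partitions(
    sample_bytes: usize,
    sample_rows: usize,
    total_rows_hint: Option<usize>,
    max_partitions: usize,
) -> usize {
    if sample_rows == 0 {
        return 1;
    }
    let bytes_per_row = sample_bytes / sample_rows;
    let total_rows = total_rows_hint.unwrap_or(sample_rows);
    let total_bytes = total_rows * bytes_per_row;
    let by_rows = total_rows.div_ceil(1_000_000);
    let by_bytes = total_bytes.div_ceil(2 * 1024 * 1024 * 1024);
    by_rows.max(by_bytes).max(1).min(max_partitions)
}

fn main() {
    // 5M rows of ~200-byte rows: ~1 GB total, so the row threshold wins:
    // ceil(5_000_000 / 1_000_000) = 5 partitions (under the cap of 8).
    assert_eq!(estimate_write_partitions(200 * 64, 64, Some(5_000_000), 8), 5);
    // Same row count but ~1 KiB rows: ~5 GiB total. Bytes ask for
    // ceil(5 GiB / 2 GiB) = 3, rows ask for 5; max(5, 3) = 5.
    assert_eq!(estimate_write_partitions(1024 * 64, 64, Some(5_000_000), 8), 5);
    // 100M rows would want 100 partitions; the cap (e.g. 32 cores) wins.
    assert_eq!(estimate_write_partitions(200, 1, Some(100_000_000), 32), 32);
}
```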
@@ -426,7 +426,6 @@ impl PermutationReader {
             row_ids_query = row_ids_query.limit(limit as usize);
         }
         let mut row_ids = row_ids_query.execute().await?;
-        let mut idx_offset = 0;
         while let Some(batch) = row_ids.try_next().await? {
             let row_ids = batch
                 .column(0)
@@ -434,9 +433,8 @@
                 .values()
                 .to_vec();
             for (i, row_id) in row_ids.iter().enumerate() {
-                offset_map.insert(i as u64 + idx_offset, *row_id);
+                offset_map.insert(i as u64, *row_id);
             }
-            idx_offset += batch.num_rows() as u64;
         }
         let offset_map = Arc::new(offset_map);
         *offset_map_ref = Some(offset_map.clone());
@@ -847,106 +845,4 @@
             .to_vec();
         assert_eq!(idx_values, vec![row_ids[2] as i32]);
     }
-
-    #[tokio::test]
-    async fn test_filtered_permutation_full_iteration() {
-        use crate::dataloader::permutation::builder::PermutationBuilder;
-
-        // Create a base table with 10000 rows where idx goes 0..10000.
-        // Filter to even values only, giving 5000 rows in the permutation.
-        let base_table = lance_datagen::gen_batch()
-            .col("idx", lance_datagen::array::step::<Int32Type>())
-            .into_mem_table("tbl", RowCount::from(10000), BatchCount::from(1))
-            .await;
-
-        let permutation_table = PermutationBuilder::new(base_table.clone())
-            .with_filter("idx % 2 = 0".to_string())
-            .build()
-            .await
-            .unwrap();
-
-        assert_eq!(permutation_table.count_rows(None).await.unwrap(), 5000);
-
-        let reader = PermutationReader::try_from_tables(
-            base_table.base_table().clone(),
-            permutation_table.base_table().clone(),
-            0,
-        )
-        .await
-        .unwrap();
-
-        assert_eq!(reader.count_rows(), 5000);
-
-        // Iterate through all batches using a batch size that doesn't evenly divide
-        // the row count (5000 / 128 = 39 full batches + 1 batch of 8 rows).
-        let batch_size = 128;
-        let mut stream = reader
-            .read(
-                Select::All,
-                QueryExecutionOptions {
-                    max_batch_length: batch_size,
-                    ..Default::default()
-                },
-            )
-            .await
-            .unwrap();
-
-        let mut total_rows = 0u64;
-        let mut all_idx_values = Vec::new();
-        while let Some(batch) = stream.try_next().await.unwrap() {
-            assert!(batch.num_rows() <= batch_size as usize);
-            total_rows += batch.num_rows() as u64;
-            let idx_col = batch.column(0).as_primitive::<Int32Type>().values();
-            all_idx_values.extend(idx_col.iter().copied());
-        }
-
-        assert_eq!(total_rows, 5000);
-        assert_eq!(all_idx_values.len(), 5000);
-
-        // Every value should be even (from the filter)
-        assert!(all_idx_values.iter().all(|v| v % 2 == 0));
-
-        // Should have 5000 unique values
-        let unique: std::collections::HashSet<i32> = all_idx_values.iter().copied().collect();
-        assert_eq!(unique.len(), 5000);
-
-        // Use take_offsets to fetch rows from the beginning, middle, and end
-        // of the permutation. The values should match what we saw during iteration.
-
-        // Beginning
-        let batch = reader.take_offsets(&[0, 1, 2], Select::All).await.unwrap();
-        assert_eq!(batch.num_rows(), 3);
-        let idx_values = batch
-            .column(0)
-            .as_primitive::<Int32Type>()
-            .values()
-            .to_vec();
-        assert_eq!(idx_values, &all_idx_values[0..3]);
-
-        // Middle
-        let batch = reader
-            .take_offsets(&[2499, 2500, 2501], Select::All)
-            .await
-            .unwrap();
-        assert_eq!(batch.num_rows(), 3);
-        let idx_values = batch
-            .column(0)
-            .as_primitive::<Int32Type>()
-            .values()
-            .to_vec();
-        assert_eq!(idx_values, &all_idx_values[2499..2502]);
-
-        // End (last 3 rows)
-        let batch = reader
-            .take_offsets(&[4997, 4998, 4999], Select::All)
-            .await
-            .unwrap();
-        assert_eq!(batch.num_rows(), 3);
-        let idx_values = batch
-            .column(0)
-            .as_primitive::<Int32Type>()
-            .values()
-            .to_vec();
-        assert_eq!(idx_values, &all_idx_values[4997..5000]);
-    }
 }
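
The two sides of the offset-map hunk above differ in whether a running `idx_offset` is carried across batches. The reasoning step the diff leaves implicit: the per-batch index `i` restarts at 0 for every batch, so when the row-id query returns more than one batch, keys from later batches collide with and overwrite earlier ones unless a global offset is added. A standalone illustration (hypothetical data, plain `HashMap` instead of the reader's internals):

```rust
use std::collections::HashMap;

fn main() {
    // Row ids arrive in two batches; the map key is meant to be the
    // global offset of the row within the permutation.
    let batches: Vec<Vec<u64>> = vec![vec![10, 11, 12], vec![20, 21]];

    // With a running idx_offset (left-hand side of the hunk), keys are
    // globally unique: 0..=4.
    let mut with_offset = HashMap::new();
    let mut idx_offset = 0u64;
    for batch in &batches {
        for (i, row_id) in batch.iter().enumerate() {
            with_offset.insert(i as u64 + idx_offset, *row_id);
        }
        idx_offset += batch.len() as u64;
    }
    assert_eq!(with_offset.len(), 5);

    // Without it (right-hand side), the second batch's keys 0..=1
    // collide with the first batch's and overwrite them.
    let mut without_offset = HashMap::new();
    for batch in &batches {
        for (i, row_id) in batch.iter().enumerate() {
            without_offset.insert(i as u64, *row_id);
        }
    }
    assert_eq!(without_offset.len(), 3); // rows at offsets 3 and 4 are lost
    println!("with: {}, without: {}", with_offset.len(), without_offset.len());
}
```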
@@ -97,7 +97,10 @@ pub type Result<T> = std::result::Result<T, Error>;
 impl From<ArrowError> for Error {
     fn from(source: ArrowError) -> Self {
         match source {
-            ArrowError::ExternalError(source) => Self::from_box_error(source),
+            ArrowError::ExternalError(source) => match source.downcast::<Self>() {
+                Ok(e) => *e,
+                Err(source) => Self::External { source },
+            },
             _ => Self::Arrow { source },
         }
     }
@@ -107,7 +110,15 @@ impl From<DataFusionError> for Error {
     fn from(source: DataFusionError) -> Self {
         match source {
             DataFusionError::ArrowError(source, _) => (*source).into(),
-            DataFusionError::External(source) => Self::from_box_error(source),
+            DataFusionError::External(source) => match source.downcast::<Self>() {
+                Ok(e) => *e,
+                Err(source) => match source.downcast::<ArrowError>() {
+                    Ok(arrow_error) => Self::Arrow {
+                        source: *arrow_error,
+                    },
+                    Err(source) => Self::External { source },
+                },
+            },
             other => Self::External {
                 source: Box::new(other),
             },
@@ -119,52 +130,15 @@ impl From<lance::Error> for Error {
     fn from(source: lance::Error) -> Self {
         // Try to unwrap external errors that were wrapped by lance
         match source {
-            lance::Error::Wrapped { error, .. } => Self::from_box_error(error),
-            lance::Error::External { source } => Self::from_box_error(source),
+            lance::Error::Wrapped { error, .. } => match error.downcast::<Self>() {
+                Ok(e) => *e,
+                Err(source) => Self::External { source },
+            },
             _ => Self::Lance { source },
         }
     }
 }
 
-impl Error {
-    fn from_box_error(mut source: Box<dyn std::error::Error + Send + Sync>) -> Self {
-        source = match source.downcast::<Self>() {
-            Ok(e) => match *e {
-                Self::External { source } => return Self::from_box_error(source),
-                other => return other,
-            },
-            Err(source) => source,
-        };
-
-        source = match source.downcast::<lance::Error>() {
-            Ok(e) => match *e {
-                lance::Error::Wrapped { error, .. } => return Self::from_box_error(error),
-                other => return other.into(),
-            },
-            Err(source) => source,
-        };
-
-        source = match source.downcast::<ArrowError>() {
-            Ok(e) => match *e {
-                ArrowError::ExternalError(source) => return Self::from_box_error(source),
-                other => return other.into(),
-            },
-            Err(source) => source,
-        };
-
-        source = match source.downcast::<DataFusionError>() {
-            Ok(e) => match *e {
-                DataFusionError::ArrowError(source, _) => return (*source).into(),
-                DataFusionError::External(source) => return Self::from_box_error(source),
-                other => return other.into(),
-            },
-            Err(source) => source,
-        };
-
-        Self::External { source }
-    }
-}
-
 impl From<object_store::Error> for Error {
     fn from(source: object_store::Error) -> Self {
         Self::ObjectStore { source }
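
Why does the removed `from_box_error` recurse through every wrapper type instead of calling `downcast` once? Because `Box::downcast` only inspects the *outermost* concrete type; an error re-wrapped by another library is invisible to a single downcast. A minimal sketch of that failure mode, using hypothetical types rather than the crate's (`AppError` stands in for the crate's `Error`):

```rust
use std::error::Error as StdError;
use std::fmt;

#[derive(Debug)]
struct AppError(String);

impl fmt::Display for AppError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "app error: {}", self.0)
    }
}
impl StdError for AppError {}

type BoxError = Box<dyn StdError + Send + Sync>;

/// One-shot downcast: the shape of a single unwrapping step, minus the
/// lance/arrow/datafusion-specific arms the real helper walks through.
fn unwrap_app_error(err: BoxError) -> Result<AppError, BoxError> {
    err.downcast::<AppError>().map(|e| *e)
}

fn main() {
    // An AppError boxed once still downcasts...
    let once: BoxError = Box::new(AppError("disk full".into()));
    assert!(unwrap_app_error(once).is_ok());

    // ...but once another error type wraps it, downcast sees only the
    // outer type. This is why from_box_error keeps peeling each wrapper
    // enum (lance::Error::Wrapped, ArrowError::ExternalError,
    // DataFusionError::External) instead of downcasting a single time.
    let wrapped = std::io::Error::new(
        std::io::ErrorKind::Other,
        Box::new(AppError("disk full".into())) as BoxError,
    );
    let nested: BoxError = Box::new(wrapped);
    assert!(unwrap_app_error(nested).is_err());
}
```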
@@ -1,131 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright The LanceDB Authors
-
-//! Expression builder API for type-safe query construction
-//!
-//! This module provides a fluent API for building expressions that can be used
-//! in filters and projections. It wraps DataFusion's expression system.
-//!
-//! # Examples
-//!
-//! ```rust
-//! use std::ops::Mul;
-//! use lancedb::expr::{col, lit};
-//!
-//! let expr = col("age").gt(lit(18));
-//! let expr = col("age").gt(lit(18)).and(col("status").eq(lit("active")));
-//! let expr = col("price") * lit(1.1);
-//! ```
-
-mod sql;
-
-pub use sql::expr_to_sql_string;
-
-use std::sync::Arc;
-
-use arrow_schema::DataType;
-use datafusion_expr::{expr_fn::cast, Expr, ScalarUDF};
-use datafusion_functions::string::expr_fn as string_expr_fn;
-
-pub use datafusion_expr::{col, lit};
-
-pub use datafusion_expr::Expr as DfExpr;
-
-pub fn lower(expr: Expr) -> Expr {
-    string_expr_fn::lower(expr)
-}
-
-pub fn upper(expr: Expr) -> Expr {
-    string_expr_fn::upper(expr)
-}
-
-pub fn contains(expr: Expr, search: Expr) -> Expr {
-    string_expr_fn::contains(expr, search)
-}
-
-pub fn expr_cast(expr: Expr, data_type: DataType) -> Expr {
-    cast(expr, data_type)
-}
-
-lazy_static::lazy_static! {
-    static ref FUNC_REGISTRY: std::sync::RwLock<std::collections::HashMap<String, Arc<ScalarUDF>>> = {
-        let mut m = std::collections::HashMap::new();
-        m.insert("lower".to_string(), datafusion_functions::string::lower());
-        m.insert("upper".to_string(), datafusion_functions::string::upper());
-        m.insert("contains".to_string(), datafusion_functions::string::contains());
-        m.insert("btrim".to_string(), datafusion_functions::string::btrim());
-        m.insert("ltrim".to_string(), datafusion_functions::string::ltrim());
-        m.insert("rtrim".to_string(), datafusion_functions::string::rtrim());
-        m.insert("concat".to_string(), datafusion_functions::string::concat());
-        m.insert("octet_length".to_string(), datafusion_functions::string::octet_length());
-        std::sync::RwLock::new(m)
-    };
-}
-
-pub fn func(name: impl AsRef<str>, args: Vec<Expr>) -> crate::Result<Expr> {
-    let name = name.as_ref();
-    let registry = FUNC_REGISTRY
-        .read()
-        .map_err(|e| crate::Error::InvalidInput {
-            message: format!("lock poisoned: {}", e),
-        })?;
-    let udf = registry
-        .get(name)
-        .ok_or_else(|| crate::Error::InvalidInput {
-            message: format!("unknown function: {}", name),
-        })?;
-    Ok(Expr::ScalarFunction(
-        datafusion_expr::expr::ScalarFunction::new_udf(udf.clone(), args),
-    ))
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_col_lit_comparisons() {
-        let expr = col("age").gt(lit(18));
-        let sql = expr_to_sql_string(&expr).unwrap();
-        assert!(sql.contains("age") && sql.contains("18"));
-
-        let expr = col("name").eq(lit("Alice"));
-        let sql = expr_to_sql_string(&expr).unwrap();
-        assert!(sql.contains("name") && sql.contains("Alice"));
-    }
-
-    #[test]
-    fn test_compound_expression() {
-        let expr = col("age").gt(lit(18)).and(col("status").eq(lit("active")));
-        let sql = expr_to_sql_string(&expr).unwrap();
-        assert!(sql.contains("age") && sql.contains("status"));
-    }
-
-    #[test]
-    fn test_string_functions() {
-        let expr = lower(col("name"));
-        let sql = expr_to_sql_string(&expr).unwrap();
-        assert!(sql.to_lowercase().contains("lower"));
-
-        let expr = contains(col("text"), lit("search"));
-        let sql = expr_to_sql_string(&expr).unwrap();
-        assert!(sql.to_lowercase().contains("contains"));
-    }
-
-    #[test]
-    fn test_func() {
-        let expr = func("lower", vec![col("x")]).unwrap();
-        let sql = expr_to_sql_string(&expr).unwrap();
-        assert!(sql.to_lowercase().contains("lower"));
-
-        let result = func("unknown_func", vec![col("x")]);
-        assert!(result.is_err());
-    }
-
-    #[test]
-    fn test_arithmetic() {
-        let expr = col("price") * lit(1.1);
-        let sql = expr_to_sql_string(&expr).unwrap();
-        assert!(sql.contains("price"));
-    }
-}
@@ -1,12 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright The LanceDB Authors
-
-use datafusion_expr::Expr;
-use datafusion_sql::unparser;
-
-pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
-    let ast = unparser::expr_to_sql(expr).map_err(|e| crate::Error::InvalidInput {
-        message: format!("failed to serialize expression to SQL: {}", e),
-    })?;
-    Ok(ast.to_string())
-}
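
The two deleted files above are thin wrappers: `col`/`lit` were re-exports from `datafusion_expr`, and `expr_to_sql_string` delegated to DataFusion's unparser. The same round-trip can be done against those crates directly — a minimal sketch assuming `datafusion-expr` and `datafusion-sql` are available (the exact SQL formatting depends on the DataFusion version):

```rust
use datafusion_expr::{col, lit};
use datafusion_sql::unparser;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build the same expression the deleted module's docs show, then
    // unparse it to a SQL string exactly as expr_to_sql_string did.
    let expr = col("age").gt(lit(18)).and(col("status").eq(lit("active")));
    let sql = unparser::expr_to_sql(&expr)?.to_string();
    // Prints something like: (age > 18) AND (status = 'active')
    println!("{sql}");
    Ok(())
}
```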
@@ -169,7 +169,6 @@ pub mod database;
 pub mod dataloader;
 pub mod embeddings;
 pub mod error;
-pub mod expr;
 pub mod index;
 pub mod io;
 pub mod ipc;
@@ -359,28 +359,6 @@ pub trait QueryBase {
     /// on the filter column(s).
     fn only_if(self, filter: impl AsRef<str>) -> Self;
 
-    /// Only return rows which match the filter, using an expression builder.
-    ///
-    /// Use [`crate::expr`] for building type-safe expressions:
-    ///
-    /// ```
-    /// use lancedb::expr::{col, lit};
-    /// use lancedb::query::{QueryBase, ExecutableQuery};
-    ///
-    /// # use lancedb::Table;
-    /// # async fn query(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
-    /// let results = table.query()
-    ///     .only_if_expr(col("age").gt(lit(18)).and(col("status").eq(lit("active"))))
-    ///     .execute()
-    ///     .await?;
-    /// # Ok(())
-    /// # }
-    /// ```
-    ///
-    /// Note: Expression filters are not supported for remote/server-side queries.
-    /// Use [`QueryBase::only_if`] with SQL strings for remote tables.
-    fn only_if_expr(self, filter: datafusion_expr::Expr) -> Self;
-
     /// Perform a full text search on the table.
     ///
     /// The results will be returned in order of BM25 scores.
@@ -490,11 +468,6 @@ impl<T: HasQuery> QueryBase for T {
         self
     }
 
-    fn only_if_expr(mut self, filter: datafusion_expr::Expr) -> Self {
-        self.mut_query().filter = Some(QueryFilter::Datafusion(filter));
-        self
-    }
-
     fn full_text_search(mut self, query: FullTextSearchQuery) -> Self {
         if self.mut_query().limit.is_none() {
             self.mut_query().limit = Some(DEFAULT_TOP_K);
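
With `only_if_expr` removed, the surviving way to express the same predicate is `only_if` with a SQL string, which works for local and remote tables alike. A short sketch of the SQL equivalent of the deleted doctest (`table` is assumed to be an open `lancedb::Table`):

```rust
use lancedb::query::{ExecutableQuery, QueryBase};

async fn adults_active(table: &lancedb::Table) -> lancedb::Result<()> {
    // SQL-string filter equivalent of
    // col("age").gt(lit(18)).and(col("status").eq(lit("active")))
    let _results = table
        .query()
        .only_if("age > 18 AND status = 'active'")
        .execute()
        .await?;
    Ok(())
}
```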
@@ -4,7 +4,6 @@
 pub mod insert;
 
 use self::insert::RemoteInsertExec;
-use crate::expr::expr_to_sql_string;
 
 use super::client::RequestResultExt;
 use super::client::{HttpSend, RestfulLanceDbClient, Sender};
@@ -202,6 +201,7 @@ impl<S: HttpSend + 'static> Tags for RemoteTags<'_, S> {
 }
 
 pub struct RemoteTable<S: HttpSend = Sender> {
+    #[allow(dead_code)]
     client: RestfulLanceDbClient<S>,
     name: String,
     namespace: Vec<String>,
@@ -447,17 +447,13 @@ impl<S: HttpSend> RemoteTable<S> {
         body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
 
         if let Some(filter) = &params.filter {
-            let filter_sql = match filter {
-                QueryFilter::Sql(sql) => sql.clone(),
-                QueryFilter::Datafusion(expr) => expr_to_sql_string(expr)?,
-                QueryFilter::Substrait(_) => {
-                    return Err(Error::NotSupported {
-                        message: "Substrait filters are not supported for remote queries"
-                            .to_string(),
-                    });
-                }
-            };
-            body["filter"] = serde_json::Value::String(filter_sql);
+            if let QueryFilter::Sql(filter) = filter {
+                body["filter"] = serde_json::Value::String(filter.clone());
+            } else {
+                return Err(Error::NotSupported {
+                    message: "querying a remote table with a non-sql filter".to_string(),
+                });
+            }
         }
 
         match &params.select {
@@ -945,12 +941,12 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
         let version = self.current_version().await;
 
         if let Some(filter) = filter {
-            let filter_sql = match filter {
-                Filter::Sql(sql) => sql.clone(),
-                Filter::Datafusion(expr) => expr_to_sql_string(&expr)?,
+            let Filter::Sql(filter) = filter else {
+                return Err(Error::NotSupported {
+                    message: "querying a remote table with a datafusion filter".to_string(),
+                });
             };
-            request =
-                request.json(&serde_json::json!({ "predicate": filter_sql, "version": version }));
+            request = request.json(&serde_json::json!({ "predicate": filter, "version": version }));
         } else {
             let body = serde_json::json!({ "version": version });
             request = request.json(&body);
@@ -4639,60 +4635,4 @@ mod tests {
         assert_eq!(result.version, 3);
         assert_eq!(attempt.load(Ordering::SeqCst), 3);
     }
-
-    #[tokio::test]
-    async fn test_query_with_datafusion_filter() {
-        use datafusion_expr::{col, lit};
-
-        let expected_data = RecordBatch::try_new(
-            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
-            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
-        )
-        .unwrap();
-        let expected_data_ref = expected_data.clone();
-
-        let table = Table::new_with_handler("my_table", move |request| {
-            assert_eq!(request.method(), "POST");
-            assert_eq!(request.url().path(), "/v1/table/my_table/query/");
-
-            let body = request.body().unwrap().as_bytes().unwrap();
-            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
-
-            // The Datafusion expression should be serialized to SQL
-            let filter = body.get("filter").expect("filter should be present");
-            let filter_str = filter.as_str().expect("filter should be a string");
-            // col("x") > lit(10) AND col("status") = lit("active")
-            assert!(
-                filter_str.contains("x") && filter_str.contains("10"),
-                "Filter should contain 'x' and '10', got: {}",
-                filter_str
-            );
-            assert!(
-                filter_str.contains("status") && filter_str.contains("active"),
-                "Filter should contain 'status' and 'active', got: {}",
-                filter_str
-            );
-
-            let response_body = write_ipc_file(&expected_data_ref);
-            http::Response::builder()
-                .status(200)
-                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
-                .body(response_body)
-                .unwrap()
-        });
-
-        // Use only_if_expr with a Datafusion expression
-        let expr = col("x").gt(lit(10)).and(col("status").eq(lit("active")));
-        let data = table
-            .query()
-            .only_if_expr(expr)
-            .execute()
-            .await
-            .unwrap()
-            .collect::<Vec<_>>()
-            .await;
-
-        assert_eq!(data.len(), 1);
-        assert_eq!(data[0].as_ref().unwrap(), &expected_data);
-    }
 }
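
The remote hunks above show why only SQL filters survive the change: the request body carries the filter as a plain string, so anything that cannot be rendered to SQL on the client has nothing to put on the wire. A sketch of that body shape as assembled in the diff (field names taken from the hunks; the filter value is illustrative):

```rust
use serde_json::json;

fn main() {
    // "k" carries the limit; "filter" is only ever a SQL string on the
    // wire, which is why non-SQL filters are rejected with
    // Error::NotSupported on the right-hand side of the diff.
    let mut body = json!({});
    let limit = 10;
    body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
    body["filter"] = serde_json::Value::String("x > 10 AND status = 'active'".to_string());
    println!("{body}");
}
```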
@@ -6,12 +6,11 @@
 use arrow_array::{RecordBatch, RecordBatchReader};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
-use datafusion_execution::TaskContext;
 use datafusion_expr::Expr;
 use datafusion_physical_plan::display::DisplayableExecutionPlan;
 use datafusion_physical_plan::ExecutionPlan;
-use futures::stream::FuturesUnordered;
 use futures::StreamExt;
+use futures::TryStreamExt;
 use lance::dataset::builder::DatasetBuilder;
 pub use lance::dataset::ColumnAlteration;
 pub use lance::dataset::NewColumnTransform;
@@ -22,6 +21,7 @@ use lance::dataset::{InsertBuilder, WriteParams};
 use lance::index::vector::utils::infer_vector_dim;
 use lance::index::vector::VectorIndexParams;
 use lance::io::{ObjectStoreParams, WrappingObjectStore};
+use lance_datafusion::exec::execute_plan;
 use lance_datafusion::utils::StreamingWriteSource;
 use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams};
 use lance_index::vector::bq::RQBuildParams;
@@ -43,7 +43,7 @@ use std::format;
 use std::path::Path;
 use std::sync::Arc;
 
-use crate::data::scannable::{estimate_write_partitions, PeekedScannable, Scannable};
+use crate::data::scannable::Scannable;
 use crate::database::Database;
 use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MemoryRegistry};
 use crate::error::{Error, Result};
@@ -2113,7 +2113,7 @@ impl BaseTable for NativeTable {
     }
 }
 
-    async fn add(&self, mut add: AddDataBuilder) -> Result<AddResult> {
+    async fn add(&self, add: AddDataBuilder) -> Result<AddResult> {
         let table_def = self.table_definition().await?;
 
         self.dataset.ensure_mutable()?;
@@ -2122,22 +2122,6 @@ impl BaseTable for NativeTable {
 
         let table_schema = Schema::from(&ds.schema().clone());
 
-        // Peek at the first batch to estimate a good partition count for
-        // write parallelism.
-        let mut peeked = PeekedScannable::new(add.data);
-        let num_partitions = if let Some(first_batch) = peeked.peek().await {
-            let max_partitions = lance_core::utils::tokio::get_num_compute_intensive_cpus();
-            estimate_write_partitions(
-                first_batch.get_array_memory_size(),
-                first_batch.num_rows(),
-                peeked.num_rows(),
-                max_partitions,
-            )
-        } else {
-            1
-        };
-        add.data = Box::new(peeked);
-
         let output = add.into_plan(&table_schema, &table_def)?;
 
         let lance_params = output
@@ -2151,41 +2135,18 @@ impl BaseTable for NativeTable {
             ..Default::default()
         });
 
-        // Repartition for write parallelism if beneficial.
-        let plan = if num_partitions > 1 {
-            Arc::new(
-                datafusion_physical_plan::repartition::RepartitionExec::try_new(
-                    output.plan,
-                    datafusion_physical_plan::Partitioning::RoundRobinBatch(num_partitions),
-                )?,
-            ) as Arc<dyn ExecutionPlan>
-        } else {
-            output.plan
-        };
+        let plan = Arc::new(InsertExec::new(
+            ds_wrapper.clone(),
+            ds,
+            output.plan,
+            lance_params,
+        ));
 
-        let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params));
-
-        // Execute all partitions in parallel.
-        let task_ctx = Arc::new(TaskContext::default());
-        let handles = FuturesUnordered::new();
-        for partition in 0..num_partitions {
-            let exec = insert_exec.clone();
-            let ctx = task_ctx.clone();
-            handles.push(tokio::spawn(async move {
-                let mut stream = exec
-                    .execute(partition, ctx)
-                    .map_err(|e| -> Error { e.into() })?;
-                while let Some(batch) = stream.next().await {
-                    batch.map_err(|e| -> Error { e.into() })?;
-                }
-                Ok::<_, Error>(())
-            }));
-        }
-        for handle in handles {
-            handle.await.map_err(|e| Error::Runtime {
-                message: format!("Insert task panicked: {}", e),
-            })??;
-        }
+        let stream = execute_plan(plan, Default::default())?;
+        stream
+            .try_collect::<Vec<_>>()
+            .await
+            .map_err(crate::Error::from)?;
 
         let version = ds_wrapper.get().await?.manifest().version;
         Ok(AddResult { version })
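
The removed code above fans the insert plan out across partitions (`RepartitionExec` with `RoundRobinBatch`) and drives each partition in its own tokio task, surfacing panics instead of ignoring them. A simplified, self-contained sketch of just the fan-out-and-join pattern — the per-partition body here is a stand-in for `exec.execute(partition, ctx)`, and the sleep is only a placeholder for draining a RecordBatch stream:

```rust
use futures::stream::{FuturesUnordered, StreamExt};

#[tokio::main]
async fn main() -> Result<(), String> {
    let num_partitions = 4;
    let mut handles = FuturesUnordered::new();
    for partition in 0..num_partitions {
        handles.push(tokio::spawn(async move {
            // Stand-in for draining the partition's RecordBatch stream.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok::<usize, String>(partition)
        }));
    }
    // FuturesUnordered polls every handle concurrently; results arrive
    // in completion order. The outer `?` surfaces panics (JoinError),
    // the inner `?` surfaces per-partition failures.
    while let Some(joined) = handles.next().await {
        let partition = joined.map_err(|e| format!("insert task panicked: {e}"))??;
        println!("partition {partition} written");
    }
    Ok(())
}
```

The right-hand side of the hunk trades this for a single `execute_plan(...)` + `try_collect()`, i.e. one sequential stream instead of per-partition tasks.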
@@ -219,7 +219,6 @@ mod tests {
     use crate::table::add_data::NaNVectorBehavior;
     use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions};
     use crate::test_utils::embeddings::MockEmbed;
-    use crate::test_utils::TestCustomError;
     use crate::Error;
 
     use super::AddDataMode;
@@ -284,20 +283,17 @@ mod tests {
         test_add_with_data(stream).await;
     }
 
-    fn assert_preserves_external_error(err: &Error) {
-        assert!(
-            matches!(err, Error::External { source } if source.downcast_ref::<TestCustomError>().is_some()),
-            "Expected Error::External, got: {err:?}"
-        );
-        // The original TestCustomError message should be preserved through the
-        // error chain, even if the error gets wrapped multiple times by
-        // lance's insert pipeline.
-        assert!(
-            err.to_string().contains("TestCustomError occurred"),
-            "Expected original error message to be preserved, got: {err}"
-        );
+    #[derive(Debug)]
+    struct MyError;
+
+    impl std::fmt::Display for MyError {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            write!(f, "MyError occurred")
+        }
     }
 
+    impl std::error::Error for MyError {}
+
     #[tokio::test]
     async fn test_add_preserves_reader_error() {
         let table = create_test_table().await;
@@ -305,7 +301,7 @@ mod tests {
         let schema = first_batch.schema();
         let iterator = vec![
             Ok(first_batch),
-            Err(ArrowError::ExternalError(Box::new(TestCustomError))),
+            Err(ArrowError::ExternalError(Box::new(MyError))),
         ];
         let reader: Box<dyn arrow_array::RecordBatchReader + Send> = Box::new(
             RecordBatchIterator::new(iterator.into_iter(), schema.clone()),
@@ -313,7 +309,7 @@ mod tests {
 
         let result = table.add(reader).execute().await;
 
-        assert_preserves_external_error(&result.unwrap_err());
+        assert!(result.is_err());
     }
 
     #[tokio::test]
@@ -324,7 +320,7 @@ mod tests {
         let iterator = vec![
             Ok(first_batch),
             Err(Error::External {
-                source: Box::new(TestCustomError),
+                source: Box::new(MyError),
             }),
         ];
         let stream = futures::stream::iter(iterator);
@@ -335,7 +331,7 @@ mod tests {
 
         let result = table.add(stream).execute().await;
 
-        assert_preserves_external_error(&result.unwrap_err());
+        assert!(result.is_err());
    }
 
     #[tokio::test]
@@ -5,7 +5,6 @@ use std::sync::Arc;
 
 use super::NativeTable;
 use crate::error::{Error, Result};
-use crate::expr::expr_to_sql_string;
 use crate::query::{
     QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest, DEFAULT_TOP_K,
 };
@@ -453,12 +452,14 @@ fn convert_to_namespace_query(query: &AnyQuery) -> Result<NsQueryTableRequest> {
 
 fn filter_to_sql(filter: &QueryFilter) -> Result<String> {
     match filter {
-        QueryFilter::Sql(sql) => Ok(sql.clone()),
-        QueryFilter::Substrait(_) => Err(Error::NotSupported {
-            message: "Substrait filters are not supported for server-side queries".to_string(),
-        }),
-        QueryFilter::Datafusion(expr) => expr_to_sql_string(expr),
-    }
+        QueryFilter::Sql(sql) => Ok(sql.clone()),
+        QueryFilter::Substrait(_) => Err(Error::NotSupported {
+            message: "Substrait filters are not supported for server-side queries".to_string(),
+        }),
+        QueryFilter::Datafusion(_) => Err(Error::NotSupported {
+            message: "Datafusion expression filters are not supported for server-side queries. Use SQL filter instead.".to_string(),
+        }),
+    }
 }
 
 /// Extract query vector(s) from Arrow arrays into the namespace format.
@@ -4,14 +4,3 @@
 pub mod connection;
 pub mod datagen;
 pub mod embeddings;
-
-#[derive(Debug)]
-pub struct TestCustomError;
-
-impl std::fmt::Display for TestCustomError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "TestCustomError occurred")
-    }
-}
-
-impl std::error::Error for TestCustomError {}