Compare commits

..

2 Commits

Author SHA1 Message Date
Yang Cen
7ff72022dd feat(query): add approx mode to vector queries (#3549)
## Feature

### What is the new feature?

Adds Rust core API support for configuring vector query approximation
mode with `ApproxMode::{Fast, Normal, Accurate}`.

### Why do we need this feature?

Lance already exposes `lance_index::vector::ApproxMode` and scanner
support for controlling the speed/accuracy tradeoff for approximate
vector search. LanceDB Rust queries need to expose and pass this setting
through for local/native and remote vector searches.

### How does it work?

- Adds public `ApproxMode` in `rust/lancedb`, with lowercase serde,
`Default::Normal`, parse/display, and conversions to/from Lance's
`ApproxMode`.
- Adds `approx_mode: Option<ApproxMode>` to `VectorQueryRequest` and a
`VectorQuery::approx_mode(...)` builder.
- Applies the mode to native/local Lance scanners after `nearest(...)`
when explicitly set.
- Sends `approx_mode` in remote query JSON only when explicitly set;
default requests omit it.

## Validation

- `cargo fmt --all`
- `cargo test --quiet --features remote approx_mode`
- `cargo test --quiet --features remote
test_query_vector_default_values`
- `cargo check --quiet --features remote --tests --examples`
- `git diff --check`
2026-06-17 20:21:02 +08:00
lancedb automation
b2e0aa0588 chore: update lance dependency to v8.0.0-beta.16 2026-06-17 01:41:42 +00:00
82 changed files with 273 additions and 8974 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.31.0-beta.4"
current_version = "0.30.1-beta.2"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -23,8 +23,6 @@ allow_dirty = true
commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
allow_shell_hooks = true
# Java maven files
pre_commit_hooks = [

190
Cargo.lock generated
View File

@@ -157,9 +157,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.103"
version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a4385e2e34eb35d6b3efe798b9eb88096925d87726c0798709bf56d9ed84af3"
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
[[package]]
name = "approx"
@@ -1297,6 +1297,15 @@ version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3"
[[package]]
name = "bitpacking"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96a7139abd3d9cebf8cd6f920a389cf3dc9576172e32f4563f188cae3c3eb019"
dependencies = [
"crunchy",
]
[[package]]
name = "bitvec"
version = "1.0.1"
@@ -1463,9 +1472,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.12.0"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ae3f5d315924270530207e2a68396c3cc547f6dca3fbdca317cfb1a51edb593"
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
[[package]]
name = "bytes-utils"
@@ -1750,7 +1759,7 @@ version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3014,7 +3023,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3177,9 +3186,9 @@ dependencies = [
[[package]]
name = "env_filter"
version = "2.0.0"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900d271a03799a1ee8d1ca9b19893b48ca674a9284fefcfb85f05e74ed314217"
checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef"
dependencies = [
"log",
"regex",
@@ -3187,9 +3196,9 @@ dependencies = [
[[package]]
name = "env_logger"
version = "0.11.11"
version = "0.11.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de671bd27a75a797dc9ae289ba1e77276e75e2026408aab65185384e2d5cd3f6"
checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a"
dependencies = [
"anstream",
"anstyle",
@@ -3231,7 +3240,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3423,8 +3432,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"rand 0.9.4",
@@ -4466,7 +4475,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi",
"libc",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -4560,7 +4569,7 @@ dependencies = [
"portable-atomic-util",
"serde_core",
"wasm-bindgen",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -4726,8 +4735,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
[[package]]
name = "lance"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arc-swap",
"arrow",
@@ -4745,6 +4754,7 @@ dependencies = [
"async_cell",
"aws-credential-types",
"aws-sdk-dynamodb",
"bitpacking",
"byteorder",
"bytes",
"chrono",
@@ -4761,9 +4771,8 @@ dependencies = [
"futures",
"half",
"humantime",
"itertools 0.14.0",
"itertools 0.13.0",
"lance-arrow",
"lance-bitpacking",
"lance-core",
"lance-datafusion",
"lance-encoding",
@@ -4801,8 +4810,8 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4823,7 +4832,7 @@ dependencies = [
[[package]]
name = "lance-arrow-scalar"
version = "58.0.0"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4837,7 +4846,7 @@ dependencies = [
[[package]]
name = "lance-arrow-stats"
version = "58.0.0"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -4846,19 +4855,18 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrayref",
"crunchy",
"paste",
"seq-macro",
]
[[package]]
name = "lance-core"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4870,7 +4878,7 @@ dependencies = [
"datafusion-common",
"datafusion-sql",
"futures",
"itertools 0.14.0",
"itertools 0.13.0",
"lance-arrow",
"lance-derive",
"libc",
@@ -4896,8 +4904,8 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow",
"arrow-array",
@@ -4927,8 +4935,8 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow",
"arrow-array",
@@ -4945,8 +4953,8 @@ dependencies = [
[[package]]
name = "lance-derive"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"proc-macro2",
"quote",
@@ -4955,8 +4963,8 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4972,7 +4980,7 @@ dependencies = [
"futures",
"hex",
"hyperloglogplus",
"itertools 0.14.0",
"itertools 0.13.0",
"lance-arrow",
"lance-bitpacking",
"lance-core",
@@ -4991,8 +4999,8 @@ dependencies = [
[[package]]
name = "lance-file"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -5022,8 +5030,8 @@ dependencies = [
[[package]]
name = "lance-index"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arc-swap",
"arrow",
@@ -5035,6 +5043,7 @@ dependencies = [
"async-channel",
"async-recursion",
"async-trait",
"bitpacking",
"bitvec",
"bytes",
"chrono",
@@ -5047,12 +5056,11 @@ dependencies = [
"fst",
"futures",
"half",
"itertools 0.14.0",
"itertools 0.13.0",
"jieba-rs",
"jsonb",
"lance-arrow",
"lance-arrow-stats",
"lance-bitpacking",
"lance-core",
"lance-datafusion",
"lance-datagen",
@@ -5088,8 +5096,8 @@ dependencies = [
[[package]]
name = "lance-io"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow",
"arrow-arith",
@@ -5130,8 +5138,8 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5142,13 +5150,12 @@ dependencies = [
"lance-core",
"num-traits",
"rand 0.9.4",
"rayon",
]
[[package]]
name = "lance-namespace"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow",
"async-trait",
@@ -5160,8 +5167,8 @@ dependencies = [
[[package]]
name = "lance-namespace-impls"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow",
"arrow-ipc",
@@ -5215,15 +5222,15 @@ dependencies = [
[[package]]
name = "lance-select"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"arrow-buffer",
"arrow-schema",
"byteorder",
"bytes",
"itertools 0.14.0",
"itertools 0.13.0",
"lance-core",
"roaring",
"tracing",
@@ -5231,8 +5238,8 @@ dependencies = [
[[package]]
name = "lance-table"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow",
"arrow-array",
@@ -5271,8 +5278,8 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -5285,21 +5292,20 @@ dependencies = [
[[package]]
name = "lance-tokenizer"
version = "9.0.0-beta.10"
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
version = "8.0.0-beta.16"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
dependencies = [
"icu_segmenter",
"jieba-rs",
"lindera",
"rust-stemmers",
"serde",
"stop-words",
"unicode-normalization",
]
[[package]]
name = "lancedb"
version = "0.31.0-beta.4"
version = "0.30.1-beta.2"
dependencies = [
"ahash",
"anyhow",
@@ -5376,14 +5382,13 @@ dependencies = [
"tokenizers",
"tokio",
"url",
"urlencoding",
"uuid",
"walkdir",
]
[[package]]
name = "lancedb-nodejs"
version = "0.31.0-beta.4"
version = "0.30.1-beta.2"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5408,7 +5413,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.34.0-beta.4"
version = "0.33.1-beta.2"
dependencies = [
"arrow",
"async-trait",
@@ -5641,9 +5646,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.33"
version = "0.4.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ceec5bc11778974d1bcb055b18002eba7f4b3518b6a0081b3af5f21666da9ad"
checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a"
[[package]]
name = "loom"
@@ -5951,9 +5956,9 @@ dependencies = [
[[package]]
name = "napi"
version = "3.9.4"
version = "3.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b41bda2ac390efb5e8d22025d925ccc3f3807d8c1bea6d19b36127247c4b8f83"
checksum = "ad513ff22558f1830b595ea6eb4091da48145d09a222ce157e781896f78be0b9"
dependencies = [
"bitflags 2.11.1",
"chrono",
@@ -5976,9 +5981,9 @@ checksum = "c9c366d2c8c60b86fa632df75f745509b52f9128f91a6bad4c796e44abb505e1"
[[package]]
name = "napi-derive"
version = "3.5.7"
version = "3.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61d66f70256ad5aef58659966064471d0ad90e2897bc36a5a5e0389c85aabc1e"
checksum = "89b3f766e04667e6da0e181e2da4f85475d5a6513b7cf6a80bea184e224a5b42"
dependencies = [
"convert_case",
"ctor 1.0.5",
@@ -5990,9 +5995,9 @@ dependencies = [
[[package]]
name = "napi-derive-backend"
version = "5.0.5"
version = "5.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81b4b08f15eed7a2a20c3f4c6314013fc3ac890a3afa9892b594485299ebdb2d"
checksum = "0d5af30503edf933ce7377cf6d4c877a62b0f1107ea05585f1b5e430e88d5baf"
dependencies = [
"convert_case",
"proc-macro2",
@@ -6085,7 +6090,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -7400,8 +7405,8 @@ version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [
"heck 0.4.1",
"itertools 0.11.0",
"heck 0.5.0",
"itertools 0.14.0",
"log",
"multimap",
"petgraph",
@@ -7420,7 +7425,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
dependencies = [
"anyhow",
"itertools 0.11.0",
"itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -7654,7 +7659,7 @@ dependencies = [
"once_cell",
"socket2 0.6.3",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
@@ -8394,7 +8399,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -8465,7 +8470,7 @@ dependencies = [
"security-framework",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -9027,7 +9032,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -9039,7 +9044,7 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -9200,15 +9205,6 @@ version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978"
[[package]]
name = "stop-words"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d68df56303396bcfb639455b3c166804aeb7994005010aab5e9e8a1277b8871d"
dependencies = [
"serde_json",
]
[[package]]
name = "str_stack"
version = "0.1.1"
@@ -9472,7 +9468,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -10120,9 +10116,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.23.4"
version = "1.23.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf80a72845275afea99e7f2b434723d3bc7e38470fcd1c7ed39a599c73319a53"
checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7"
dependencies = [
"getrandom 0.4.2",
"js-sys",
@@ -10407,7 +10403,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]

View File

@@ -13,20 +13,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-core = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-datagen = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-file = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-io = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-index = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-linalg = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-namespace = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-namespace-impls = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-table = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-testing = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-datafusion = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-encoding = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance-arrow = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
lance = { "version" = "=8.0.0-beta.16", default-features = false, "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=8.0.0-beta.16", default-features = false, "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=8.0.0-beta.16", default-features = false, "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "58.0.0", optional = false }

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.31.0-beta.4</version>
<version>0.30.1-beta.2</version>
</dependency>
```

View File

@@ -1,29 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthFlowType
# Enumeration: OAuthFlowType
OAuth authentication flow types.
## Enumeration Members
### AzureManagedIdentity
```ts
AzureManagedIdentity: "azure_managed_identity";
```
Azure Managed Identity via IMDS.
***
### ClientCredentials
```ts
ClientCredentials: "client_credentials";
```
Client Credentials grant (service-to-service / M2M).

View File

@@ -12,7 +12,6 @@
## Enumerations
- [FullTextQueryType](enumerations/FullTextQueryType.md)
- [OAuthFlowType](enumerations/OAuthFlowType.md)
- [Occur](enumerations/Occur.md)
- [Operator](enumerations/Operator.md)
@@ -86,8 +85,6 @@
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
- [MergeResult](interfaces/MergeResult.md)
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
- [OAuthConfig](interfaces/OAuthConfig.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
- [OptimizeStats](interfaces/OptimizeStats.md)

View File

@@ -64,19 +64,6 @@ client used by manifest-enabled native connections.
***
### oauthConfig?
```ts
optional oauthConfig: NativeOAuthConfig;
```
(For LanceDB cloud only): OAuth configuration for IdP-based
authentication (e.g., Azure Entra ID). When set, token acquisition
and refresh are handled entirely in Rust. TypeScript users should pass
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
***
### readConsistencyInterval?
```ts

View File

@@ -1,88 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / NativeOAuthConfig
# Interface: NativeOAuthConfig
OAuth configuration for LanceDB authentication.
This is the generated napi-rs binding shape. TypeScript users should prefer
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
All token acquisition and refresh is handled in the Rust layer.
## Properties
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for client_credentials).
***
### flow?
```ts
optional flow: string;
```
Authentication flow: "client_credentials" or "azure_managed_identity"
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (azure_managed_identity).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
Keep this well below the token TTL; if it is greater than or equal to
the TTL, each request refreshes the token.
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request. For Azure managed identity, exactly one scope
or resource is required. For example: `["api://{app_id}/.default"]`

View File

@@ -1,111 +0,0 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthConfig
# Interface: OAuthConfig
OAuth configuration for LanceDB authentication.
This is the public TypeScript OAuth configuration type. The generated
`NativeOAuthConfig` type has the same runtime shape but is an implementation
detail of the napi-rs binding.
All token acquisition and refresh is handled in the Rust layer.
This config is passed through to Rust via napi-rs.
## Examples
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
clientSecret: "secret",
scopes: ["api://lancedb-api/.default"],
};
```
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
scopes: ["api://lancedb-api/.default"],
flow: OAuthFlowType.AzureManagedIdentity,
};
```
## Properties
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for ClientCredentials).
***
### flow?
```ts
optional flow: OAuthFlowType;
```
Authentication flow (default: ClientCredentials).
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (AzureManagedIdentity).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
Keep this well below the token TTL; if it is greater than or equal to
the TTL, each request refreshes the token.
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request.
For Azure managed identity, exactly one scope or resource is required.
For example: `["api://{app_id}/.default"]`

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.31.0-beta.4</version>
<version>0.30.1-beta.2</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.31.0-beta.4</version>
<version>0.30.1-beta.2</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>9.0.0-beta.10</lance-core.version>
<lance-core.version>8.0.0-beta.16</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.31.0-beta.4"
version = "0.30.1-beta.2"
publish = false
license.workspace = true
description.workspace = true

View File

@@ -52,7 +52,6 @@ export {
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
OAuthConfig as NativeOAuthConfig,
} from "./native.js";
export {
@@ -131,8 +130,6 @@ export {
TokenResponse,
} from "./header";
export { OAuthConfig, OAuthFlowType } from "./oauth";
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";

View File

@@ -1,76 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
/**
* OAuth authentication flow types.
*/
export enum OAuthFlowType {
/** Client Credentials grant (service-to-service / M2M). */
ClientCredentials = "client_credentials",
/** Azure Managed Identity via IMDS. */
AzureManagedIdentity = "azure_managed_identity",
}
/**
* OAuth configuration for LanceDB authentication.
*
* This is the public TypeScript OAuth configuration type. The generated
* `NativeOAuthConfig` type has the same runtime shape but is an implementation
* detail of the napi-rs binding.
*
* All token acquisition and refresh is handled in the Rust layer.
* This config is passed through to Rust via napi-rs.
*
* @example Client Credentials (service-to-service):
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* clientSecret: "secret",
* scopes: ["api://lancedb-api/.default"],
* };
* ```
*
* @example Azure Managed Identity:
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* scopes: ["api://lancedb-api/.default"],
* flow: OAuthFlowType.AzureManagedIdentity,
* };
* ```
*/
export interface OAuthConfig {
/**
* OIDC issuer URL or OAuth authority URL.
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
*/
issuerUrl: string;
/** Application / Client ID. */
clientId: string;
/**
* OAuth scopes to request.
* For Azure managed identity, exactly one scope or resource is required.
* For example: `["api://{app_id}/.default"]`
*/
scopes: string[];
/** Authentication flow (default: ClientCredentials). */
flow?: OAuthFlowType;
/** Client secret (required for ClientCredentials). */
clientSecret?: string;
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
managedIdentityClientId?: string;
/**
* Seconds before expiry to trigger proactive refresh (default: 300).
* Keep this well below the token TTL; if it is greater than or equal to
* the TTL, each request refreshes the token.
*/
refreshBufferSecs?: number;
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.31.0-beta.4",
"version": "0.30.1-beta.2",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -112,12 +112,6 @@ impl Connection {
builder = builder.client_config(rust_config);
if let Some(oauth_config) = options.oauth_config {
let config: lancedb::remote::oauth::OAuthConfig =
oauth_config.try_into().default_error()?;
builder = builder.oauth_config(config);
}
if let Some(api_key) = options.api_key {
builder = builder.api_key(&api_key);
}

View File

@@ -65,11 +65,6 @@ pub struct ConnectionOptions {
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
/// for testing purposes.
pub host_override: Option<String>,
/// (For LanceDB cloud only): OAuth configuration for IdP-based
/// authentication (e.g., Azure Entra ID). When set, token acquisition
/// and refresh are handled entirely in Rust. TypeScript users should pass
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
pub oauth_config: Option<remote::OAuthConfig>,
}
#[napi(object)]

View File

@@ -3,7 +3,6 @@
use std::collections::HashMap;
use lancedb::error::Error;
use napi_derive::*;
/// Timeout configuration for remote HTTP client.
@@ -141,84 +140,6 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
}
}
/// OAuth configuration for LanceDB authentication.
///
/// This is the generated napi-rs binding shape. TypeScript users should prefer
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
///
/// All token acquisition and refresh is handled in the Rust layer.
#[napi(object)]
#[derive(Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// OAuth scopes to request. For Azure managed identity, exactly one scope
/// or resource is required. For example: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow: "client_credentials" or "azure_managed_identity"
pub flow: Option<String>,
/// Client secret (required for client_credentials).
pub client_secret: Option<String>,
/// Client ID for user-assigned managed identity (azure_managed_identity).
pub managed_identity_client_id: Option<String>,
/// Seconds before expiry to trigger proactive refresh (default: 300).
/// Keep this well below the token TTL; if it is greater than or equal to
/// the TTL, each request refreshes the token.
pub refresh_buffer_secs: Option<u32>,
}
impl std::fmt::Debug for OAuthConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthConfig")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field("scopes", &self.scopes)
.field("flow", &self.flow)
.field(
"client_secret",
&self.client_secret.as_deref().map(|_| "<redacted>"),
)
.field(
"managed_identity_client_id",
&self.managed_identity_client_id,
)
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
.finish()
}
}
impl TryFrom<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
type Error = Error;
fn try_from(config: OAuthConfig) -> Result<Self, Self::Error> {
use lancedb::remote::oauth::OAuthFlow;
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
"client_credentials" => OAuthFlow::ClientCredentials,
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
client_id: config.managed_identity_client_id,
},
other => {
return Err(Error::InvalidInput {
message: format!("Unknown OAuth flow type: {other}"),
});
}
};
Ok(Self {
issuer_url: config.issuer_url,
client_id: config.client_id,
client_secret: config.client_secret,
scopes: config.scopes,
flow,
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
})
}
}
impl From<ClientConfig> for lancedb::remote::ClientConfig {
fn from(config: ClientConfig) -> Self {
Self {
@@ -235,45 +156,3 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_unknown_oauth_flow_returns_invalid_input() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
scopes: vec!["scope".to_string()],
flow: Some("typo".to_string()),
client_secret: None,
managed_identity_client_id: None,
refresh_buffer_secs: None,
};
let err = lancedb::remote::oauth::OAuthConfig::try_from(config).unwrap_err();
assert!(matches!(
err,
Error::InvalidInput { message }
if message == "Unknown OAuth flow type: typo"
));
}
#[test]
fn test_oauth_config_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
scopes: vec!["scope".to_string()],
flow: Some("client_credentials".to_string()),
client_secret: Some("super-secret".to_string()),
managed_identity_client_id: None,
refresh_buffer_secs: None,
};
let debug = format!("{config:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
}
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.34.0-beta.4"
current_version = "0.33.1-beta.2"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -23,8 +23,6 @@ allow_dirty = true
commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
allow_shell_hooks = true
# Update Cargo.lock after version bump
pre_commit_hooks = [

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.34.0-beta.4"
version = "0.33.1-beta.2"
publish = false
edition.workspace = true
description = "Python bindings for LanceDB"

View File

@@ -17,17 +17,6 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote import ClientConfig
from .remote.db import RemoteDBConnection
from .expr import Expr, col, lit, func
from .udf import (
udf,
table_udf,
Udf,
JobHandle,
JobFailedError,
MaterializedView,
AsyncJobHandle,
AsyncMaterializedView,
)
from .lineage import Lineage, Node, Edge, FunctionRef
from .schema import vector
from .table import AsyncTable, Table
from ._lancedb import Session
@@ -100,8 +89,6 @@ def connect(
If presented, connect to LanceDB cloud.
Otherwise, connect to a database on file system or cloud storage.
Can be set via environment variable `LANCEDB_API_KEY`.
OAuth configuration is currently supported only by ``connect_async``;
synchronous LanceDB Cloud connections require an API key.
region: str, default "us-east-1"
The region to use for LanceDB Cloud.
host_override: str, optional
@@ -353,7 +340,6 @@ async def connect_async(
session: Optional[Session] = None,
manifest_enabled: bool = False,
namespace_client_properties: Optional[Dict[str, str]] = None,
oauth_config=None,
) -> AsyncConnection:
"""Connect to a LanceDB database.
@@ -403,10 +389,6 @@ async def connect_async(
namespace_client_properties : dict, optional
Additional directory namespace client properties to use with
``manifest_enabled=True``.
oauth_config : OAuthConfig, optional
OAuth configuration for LanceDB Cloud/Enterprise. This is supported by
``connect_async`` only; synchronous ``connect`` uses API key
authentication for ``db://`` URIs.
Examples
--------
@@ -453,24 +435,11 @@ async def connect_async(
session,
manifest_enabled,
namespace_client_properties,
oauth_config,
)
)
__all__ = [
"udf",
"table_udf",
"Udf",
"JobHandle",
"JobFailedError",
"MaterializedView",
"AsyncJobHandle",
"AsyncMaterializedView",
"Lineage",
"Node",
"Edge",
"FunctionRef",
"connect",
"connect_async",
"connect_namespace",

View File

@@ -280,7 +280,6 @@ async def connect(
session: Optional[Session],
manifest_enabled: bool = False,
namespace_client_properties: Optional[Dict[str, str]] = None,
oauth_config: Optional[Any] = None,
) -> Connection: ...
class RecordBatchStream:

View File

@@ -65,7 +65,6 @@ if TYPE_CHECKING:
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from ._lancedb import Session
from .udf import MaterializedView, AsyncMaterializedView
from .namespace_utils import (
_normalize_create_namespace_mode,
@@ -563,259 +562,6 @@ class DBConnection(EnforceOverrides):
"""
raise NotImplementedError("serialize is not supported for this connection type")
# -- Derived compute: functions, materialized views, jobs -------------
# Server-backed features (LanceDB Enterprise / Cloud); local
# connections raise NotImplementedError for now.
def create_function(
self,
name,
language: str = "python",
return_type: Optional[str] = None,
body: Optional[str] = None,
options: Optional[Dict[str, str]] = None,
*,
replace: bool = False,
):
"""Register a UDF (CREATE FUNCTION).
Pass a ``@udf`` / ``@table_udf``-decorated function (preferred):
db.create_function(embed)
or the explicit fields:
Parameters
----------
name: str or Udf
A decorated UDF object, or the function name.
language: str
Implementation language (currently "python").
return_type: str
SQL return type, e.g. "FLOAT", "FLOAT[1536]",
"STRUCT(a FLOAT, b VARCHAR)", "TABLE(chunk VARCHAR, idx INT)".
body: str
Function body: source text, or base64 cloudpickle bytes when
options["body_format"] == "cloudpickle".
options: dict, optional
input_columns, pip, num_gpus, batch_size, timeout,
error_policy, docker_image, body_format, ...
replace: bool
Drop an existing function of the same name first.
"""
from .udf import Udf
if isinstance(name, Udf):
req = name.create_request()
name, language, return_type, body, options = (
req["name"],
req["language"],
req["return_type"],
req["body"],
req["options"],
)
if replace:
try:
self.drop_function(name)
except Exception:
pass
LOOP.run(self._conn.create_function(name, language, return_type, body, options))
def list_functions(self):
"""List registered functions (SHOW FUNCTIONS)."""
return LOOP.run(self._conn.list_functions())
def drop_function(self, name: str):
"""Drop a registered function (DROP FUNCTION)."""
LOOP.run(self._conn.drop_function(name))
def create_materialized_view(
self,
name: str,
source=None,
select=None,
*,
query: Optional[str] = None,
where: Optional[str] = None,
auto_refresh: bool = False,
with_no_data: bool = False,
replace: bool = False,
partition_by: Optional[str] = None,
) -> "MaterializedView":
"""Create a materialized view (CREATE MATERIALIZED VIEW); returns a
`MaterializedView` handle (``.wait()`` blocks until it is populated).
Two ways to specify the view body:
- ergonomic: pass ``source`` (a table name or table) and ``select``
items -- column names, expression strings ("embed(body)"),
(alias, expression) tuples, or ``@udf`` / ``@table_udf`` objects.
The SELECT is assembled and parsed server-side (one parser, shared
with SQL).
- raw: pass ``query=`` with a full SELECT, e.g.
"SELECT id, embed(body) AS vec FROM articles WHERE id > 1".
`partition_by` partitions the view's (single) table function on a source
column. If that column has an IVF vector index the server partitions by
its index clusters (image-dedup style); otherwise it groups by distinct
value. (Geneva's `partition_by` and `partition_by_indexed_column` unify
here -- the engine picks the strategy from the column.)
"""
from .udf import build_view_query, MaterializedView
if query is None:
if source is None or select is None:
raise ValueError(
"create_materialized_view needs either query= or both "
"source and select"
)
query = build_view_query(source, select)
if where:
query += f" WHERE {where}"
if replace:
self._drop_view_if_exists(name)
job_id = LOOP.run(
self._conn.create_materialized_view(
name,
query=query,
auto_refresh=auto_refresh,
with_no_data=with_no_data,
partition_by=partition_by,
)
)
return MaterializedView(self, name, job_id=job_id)
def _drop_view_if_exists(self, name: str) -> None:
# `replace=True` is "drop if present"; only a not-found error is
# benign here. Anything else (perms, server fault) must surface rather
# than be masked by a later create failure.
try:
self.drop_materialized_view(name)
except Exception as e:
msg = str(e).lower()
if "not found" not in msg and "does not exist" not in msg:
raise
def job(self, job_id: str):
"""A `JobHandle` for reconnecting to an inflight job by id -- e.g. an
id you stored, or one returned from the SQL / REST surface. Submit
methods (`refresh_column`, `MaterializedView.refresh`) already return a
handle directly, so you do not need this to wait on a fresh submission."""
from .udf import JobHandle
return JobHandle(self, job_id)
def lineage(
self,
table: str,
column: Optional[str] = None,
*,
direction: Optional[str] = None,
depth: Optional[int] = None,
):
"""Derived-compute lineage of a table/view, or one of its columns:
upstream sources, downstream dependents, and the function version +
location that produced each derived column (with a drift flag). Returns
a `Lineage`. `direction` is "upstream" | "downstream" | "both" (server
default both); `depth` limits column-hops (transitive when omitted)."""
# `self._conn` is the AsyncConnection; drive its async `lineage`
# (which parses the JSON) on the loop, mirroring create_materialized_view.
return LOOP.run(
self._conn.lineage(table, column, direction=direction, depth=depth)
)
def _refresh_materialized_view(
self,
name: str,
*,
full: bool = False,
src_version: Optional[int] = None,
num_workers: Optional[int] = None,
max_workers: Optional[int] = None,
) -> str:
"""Internal: submit a materialized-view refresh, return the job id.
The public surface is ``MaterializedView.refresh()`` (which returns a
`JobHandle`); this stays private so refresh is only reached through the
handle.
``full=True`` forces a full rebuild (recompute and replace every row)
instead of the default incremental refresh.
"""
return LOOP.run(
self._conn._refresh_materialized_view(
name,
full=full,
src_version=src_version,
num_workers=num_workers,
max_workers=max_workers,
)
)
def explain_refresh_materialized_view(
self,
name: str,
*,
full: bool = False,
src_version: Optional[int] = None,
):
"""Plan a refresh without running it (EXPLAIN REFRESH). Returns a
plan with .has_work / .source_version / .last_refreshed_version /
.full_refresh / .rebuild / .units_total. `full=True` plans a full
rebuild (incremental planning needs stable row IDs on the source)."""
return LOOP.run(
self._conn.explain_refresh_materialized_view(
name, full=full, src_version=src_version
)
)
def alter_materialized_view(self, name: str, *, auto_refresh: bool):
"""Update a materialized view's options (ALTER MATERIALIZED VIEW)."""
LOOP.run(self._conn.alter_materialized_view(name, auto_refresh=auto_refresh))
def drop_materialized_view(self, name: str):
"""Drop a materialized view definition (DROP MATERIALIZED VIEW)."""
LOOP.run(self._conn.drop_materialized_view(name))
def list_materialized_views(self):
"""List registered materialized view definitions."""
return LOOP.run(self._conn.list_materialized_views())
def list_jobs(self):
"""List inflight server-side jobs across the database's tables."""
return LOOP.run(self._conn.list_jobs())
def get_job(self, job_id: str, table: "str | None" = None):
"""Look up one server-side job by id (the wait()/status poll path).
Passing ``table`` (the job's table) lets the server answer with an O(1)
single-node read instead of scanning the database's active jobs.
Returns the job's status, or None if it's unknown or no longer active.
"""
return LOOP.run(self._conn.get_job(job_id, table))
def cancel_job(self, job_id: str) -> bool:
"""Cancel an inflight server-side job by id (CANCEL JOB).
Returns True if a matching inflight job was found and flagged for
cancellation, False if none was inflight (already finished or
unknown id) -- cancellation is best-effort.
"""
return LOOP.run(self._conn.cancel_job(job_id))
def job_history(self, job_id: "str | None" = None):
"""Durable history of completed server-side jobs (SHOW JOB HISTORY).
Pass ``job_id`` to narrow to a single job. Unlike :meth:`list_jobs`
(live, inflight) these are the terminal records.
"""
return LOOP.run(self._conn.job_history(job_id))
def errors(self, job_id: "str | None" = None, table: "str | None" = None):
"""Per-row UDF errors recorded by ``error_policy=skip`` (SHOW ERRORS),
optionally filtered by ``job_id`` and/or ``table``.
"""
return LOOP.run(self._conn.errors(job_id, table))
class LanceDBConnection(DBConnection):
"""
@@ -2041,200 +1787,6 @@ class AsyncConnection(object):
)
return AsyncTable(table)
# -- Derived compute: functions, materialized views, jobs -------------
# Server-backed features (LanceDB Enterprise / Cloud); local
# connections raise NotImplementedError for now.
async def create_function(
self,
name,
language: str = "python",
return_type: Optional[str] = None,
body: Optional[str] = None,
options: Optional[Dict[str, str]] = None,
*,
replace: bool = False,
):
"""Register a UDF (CREATE FUNCTION). Accepts a ``@udf``/``@table_udf``
object (preferred) or the explicit (name, language, return_type, body,
options)."""
from .udf import Udf
if isinstance(name, Udf):
req = name.create_request()
name, language, return_type, body, options = (
req["name"],
req["language"],
req["return_type"],
req["body"],
req["options"],
)
if replace:
try:
await self.drop_function(name)
except Exception:
pass
await self._inner.create_function(name, language, return_type, body, options)
async def list_functions(self):
"""List registered functions (SHOW FUNCTIONS)."""
return await self._inner.list_functions()
async def drop_function(self, name: str):
"""Drop a registered function (DROP FUNCTION)."""
await self._inner.drop_function(name)
async def create_materialized_view(
self,
name: str,
source=None,
select=None,
*,
query: Optional[str] = None,
where: Optional[str] = None,
auto_refresh: bool = False,
with_no_data: bool = False,
replace: bool = False,
partition_by: Optional[str] = None,
) -> "AsyncMaterializedView":
"""Create a materialized view; returns an `AsyncMaterializedView`
handle (``.wait()`` blocks until populated). Pass either ``query=`` (a
full SELECT) or ``source`` + ``select`` items; `partition_by`
partitions the view's table function on a source column (index-cluster
if the column is IVF-indexed, else distinct-value). See the sync
method for the select grammar."""
from .udf import build_view_query, AsyncMaterializedView
if query is None:
if source is None or select is None:
raise ValueError(
"create_materialized_view needs either query= or both "
"source and select"
)
query = build_view_query(source, select)
if where:
query += f" WHERE {where}"
if replace:
try:
await self.drop_materialized_view(name)
except Exception as e:
msg = str(e).lower()
if "not found" not in msg and "does not exist" not in msg:
raise
job_id = await self._inner.create_materialized_view(
name,
query,
auto_refresh=auto_refresh,
with_no_data=with_no_data,
partition_by=partition_by,
)
return AsyncMaterializedView(self, name, job_id=job_id)
def job(self, job_id: str):
"""An `AsyncJobHandle` for reconnecting to an inflight job by id (a
stored id, or one from the SQL / REST surface). Submit methods already
return a handle, so this is only needed to re-attach to an existing
job."""
from .udf import AsyncJobHandle
return AsyncJobHandle(self, job_id)
async def lineage(
self,
table: str,
column: Optional[str] = None,
*,
direction: Optional[str] = None,
depth: Optional[int] = None,
):
"""Derived-compute lineage of a table/view (or column). See the sync
`Connection.lineage`. Returns a `Lineage`."""
from .lineage import Lineage
raw = await self._inner.table_lineage(table, column, direction, depth)
return Lineage.from_json(raw)
async def _refresh_materialized_view(
self,
name: str,
*,
full: bool = False,
src_version: Optional[int] = None,
num_workers: Optional[int] = None,
max_workers: Optional[int] = None,
) -> str:
"""Internal: submit a refresh, return the job id. The public surface is
``AsyncMaterializedView.refresh()`` (returns an `AsyncJobHandle`).
``full=True`` forces a full rebuild (recompute and replace every row)
instead of the default incremental refresh.
"""
return await self._inner.refresh_materialized_view(
name,
full=full,
src_version=src_version,
num_workers=num_workers,
max_workers=max_workers,
)
async def explain_refresh_materialized_view(
self,
name: str,
*,
full: bool = False,
src_version: Optional[int] = None,
):
"""Plan a refresh without running it (EXPLAIN REFRESH)."""
return await self._inner.explain_refresh_materialized_view(
name, full=full, src_version=src_version
)
async def alter_materialized_view(self, name: str, *, auto_refresh: bool):
"""Update a materialized view's options."""
await self._inner.alter_materialized_view(name, auto_refresh)
async def drop_materialized_view(self, name: str):
"""Drop a materialized view definition."""
await self._inner.drop_materialized_view(name)
async def list_materialized_views(self):
"""List registered materialized view definitions."""
return await self._inner.list_materialized_views()
async def list_jobs(self):
"""List inflight server-side jobs across the database's tables."""
return await self._inner.list_jobs()
async def get_job(self, job_id: str, table: "str | None" = None):
"""Look up one server-side job by id (the wait()/status poll path).
``table`` (the job's table) enables an O(1) server-side lookup.
Returns the job's status, or None if unknown / no longer active."""
return await self._inner.get_job(job_id, table)
async def cancel_job(self, job_id: str) -> bool:
"""Cancel an inflight server-side job by id (CANCEL JOB).
Returns True if a matching inflight job was found and flagged for
cancellation, False otherwise (best-effort).
"""
return await self._inner.cancel_job(job_id)
async def job_history(self, job_id: "str | None" = None):
"""Durable history of completed server-side jobs (SHOW JOB HISTORY).
Reads each table's durable job-history store. Pass ``job_id`` to narrow
to a single job. Unlike :meth:`list_jobs` (live, inflight) these are the
terminal records, with created/updated/completed timestamps.
"""
return await self._inner.job_history(job_id)
async def errors(self, job_id: "str | None" = None, table: "str | None" = None):
"""Per-row UDF errors recorded by ``error_policy=skip`` (SHOW ERRORS).
Optionally filtered by ``job_id`` and/or ``table``.
"""
return await self._inner.errors(job_id, table)
async def rename_table(
self,
cur_name: str,

View File

@@ -81,7 +81,6 @@ class ColPaliEmbeddings(EmbeddingFunction):
warnings.warn(
"use_token_pooling is deprecated, use pooling_strategy=None instead",
DeprecationWarning,
stacklevel=2,
)
self.pooling_strategy = None

View File

@@ -1,177 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Client-side model of derived-compute lineage.
`Connection.lineage()` / `Table.lineage()` / `MaterializedView.lineage()` return
a `Lineage`: the graph of what a column or materialized view derives from
(upstream), what derives from it (downstream), and -- for each derived column --
the function that produced it, the version it was produced with, and whether
that is stale relative to the function the registry now holds.
The server returns this as JSON (the wire contract); these classes deserialize
it. Nothing here talks to the server.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import List, Optional, Union
@dataclass
class FunctionRef:
"""The function that produced a derived column, with version + location."""
name: str
#: Version that produced the data (stamped at compute time), if known.
as_computed_version: Optional[str] = None
#: Version the registry currently holds for this function name.
current_version: Optional[str] = None
#: True when the column was produced by an older function than the registry
#: now holds -- i.e. silently stale; re-refresh to catch up.
stale_vs_current: bool = False
language: Optional[str] = None
docker_image: Optional[str] = None
env_digest: Optional[str] = None
code_uri: Optional[str] = None
@classmethod
def _from(cls, d: dict) -> "FunctionRef":
return cls(
name=d["name"],
as_computed_version=d.get("as_computed_version"),
current_version=d.get("current_version"),
stale_vs_current=d.get("stale_vs_current", False),
language=d.get("language"),
docker_image=d.get("docker_image"),
env_digest=d.get("env_digest"),
code_uri=d.get("code_uri"),
)
@dataclass
class Node:
"""A lineage node: a table, view, column, or function."""
kind: str # "table" | "view" | "column" | "function"
id: str # "table", "table.column", or "fn:name@version"
table: Optional[str] = None
function: Optional[FunctionRef] = None
@classmethod
def _from(cls, d: dict) -> "Node":
fn = d.get("function")
return cls(
kind=d["kind"],
id=d["id"],
table=d.get("table"),
function=FunctionRef._from(fn) if fn else None,
)
@dataclass
class Edge:
"""`downstream` depends on `upstream`, produced by `via` (a function name,
or None for a passthrough)."""
downstream: str
upstream: str
via: Optional[str] = None
@classmethod
def _from(cls, d: dict) -> "Edge":
return cls(downstream=d["downstream"], upstream=d["upstream"], via=d.get("via"))
@dataclass
class Lineage:
"""A derived-compute lineage graph (nodes + labeled edges)."""
target: str
nodes: List[Node] = field(default_factory=list)
edges: List[Edge] = field(default_factory=list)
@classmethod
def from_json(cls, raw: Union[str, bytes, dict]) -> "Lineage":
d = json.loads(raw) if isinstance(raw, (str, bytes)) else raw
return cls(
target=d.get("target", ""),
nodes=[Node._from(n) for n in d.get("nodes", [])],
edges=[Edge._from(e) for e in d.get("edges", [])],
)
def functions(self) -> List[FunctionRef]:
"""The function nodes in the graph."""
return [n.function for n in self.nodes if n.function is not None]
def stale(self) -> List[FunctionRef]:
"""Functions whose as-computed version is behind the current registry
version -- the columns they produced are silently out of date."""
return [f for f in self.functions() if f.stale_vs_current]
def to_dict(self) -> dict:
def prune(d: dict) -> dict:
return {k: v for k, v in d.items() if v is not None}
return {
"target": self.target,
"nodes": [
prune(
{
"kind": n.kind,
"id": n.id,
"table": n.table,
"function": prune(vars(n.function)) if n.function else None,
}
)
for n in self.nodes
],
"edges": [prune(vars(e)) for e in self.edges],
}
def to_graphviz(self) -> str:
"""Graphviz DOT for the lineage DAG: columns/tables as nodes, function
names on edges, drift edges dashed + red."""
stale_names = {f.name for f in self.stale()}
out = [
"digraph lineage {",
" rankdir=LR;",
' node [fontname="monospace"];',
]
for n in self.nodes:
if n.kind == "function":
continue
shape = "ellipse" if n.kind in ("table", "view") else "box"
out.append(f' "{n.id}" [shape={shape}];')
for e in self.edges:
attrs = ""
if e.via:
if e.via in stale_names:
attrs = f' [label="{e.via}" color=red style=dashed]'
else:
attrs = f' [label="{e.via}"]'
out.append(f' "{e.upstream}" -> "{e.downstream}"{attrs};')
out.append("}")
return "\n".join(out)
def _repr_html_(self) -> str:
warn = ""
drift = self.stale()
if drift:
names = ", ".join(sorted({f.name for f in drift}))
warn = (
f'<p style="color:#b00000"><b>stale vs current:</b> {names} '
"(re-refresh to catch up)</p>"
)
rows = "".join(
f"<tr><td><code>{e.downstream}</code></td>"
f"<td>&larr; {e.via or ''}</td>"
f"<td><code>{e.upstream}</code></td></tr>"
for e in self.edges
)
return (
f"<b>lineage: <code>{self.target}</code></b>{warn}"
"<table><tr><th>derived</th><th>via</th><th>from</th></tr>"
f"{rows}</table>"
)

View File

@@ -71,9 +71,6 @@ from lancedb.embeddings import EmbeddingFunctionConfig
from ._lancedb import Session
_MAX_QUERY_K = 2**31 - 1
def _query_to_namespace_request(
table_id: List[str],
query: "Query",
@@ -151,8 +148,7 @@ def _query_to_namespace_request(
if query.limit is not None:
k = query.limit
elif query.vector is None and query.full_text_query is None:
# limit k to max i32 value to avoid client overflows
k = _MAX_QUERY_K
k = sys.maxsize
else:
k = 10
@@ -373,19 +369,6 @@ def _convert_pyarrow_schema_to_json(schema: pa.Schema) -> JsonArrowSchema:
return JsonArrowSchema(fields=fields, metadata=meta)
def _builds_namespace_natively(
namespace_client_impl: Optional[str],
namespace_client_properties: Optional[Dict[str, str]],
) -> bool:
"""Whether ``connect_namespace_client`` builds the namespace client natively
in Rust (installing the read-freshness context provider) rather than wrapping
the pre-built Python client.
Must mirror Rust ``build_namespace_natively`` in ``python/src/connection.rs``.
"""
return namespace_client_impl == "rest" and bool(namespace_client_properties)
class LanceNamespaceDBConnection(DBConnection):
"""
A LanceDB connection that uses a namespace for table management.
@@ -445,13 +428,6 @@ class LanceNamespaceDBConnection(DBConnection):
)
self._namespace_client_impl = namespace_client_impl
self._namespace_client_properties = namespace_client_properties
# When the namespace client is built natively (see Rust
# ``build_namespace_natively``), the underlying Rust table performs
# QueryTable pushdown through the read-freshness context provider, which
# the pure-Python ``query_table`` path bypasses.
self._route_pushdown_to_rust = _builds_namespace_natively(
namespace_client_impl, namespace_client_properties
)
self._inner = AsyncConnection(
_connect_namespace_client(
namespace_client,
@@ -563,7 +539,6 @@ class LanceNamespaceDBConnection(DBConnection):
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
_async=async_table,
)
@@ -601,7 +576,6 @@ class LanceNamespaceDBConnection(DBConnection):
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
_async=async_table,
)
if branch is not None:
@@ -897,8 +871,6 @@ class AsyncLanceNamespaceDBConnection:
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
namespace_client_pushdown_operations: Optional[List[str]] = None,
namespace_client_impl: Optional[str] = None,
namespace_client_properties: Optional[Dict[str, str]] = None,
):
"""
Initialize an async namespace-based LanceDB connection.
@@ -924,12 +896,6 @@ class AsyncLanceNamespaceDBConnection:
namespace.create_table() instead of using declare_table + local write.
Default is None (no pushdown, all operations run locally).
namespace_client_impl : Optional[str]
The namespace implementation name used to create this connection.
Required (with ``namespace_client_properties``) for the Rust client to
be built natively and install the read-freshness provider.
namespace_client_properties : Optional[Dict[str, str]]
The namespace properties used to create this connection.
"""
self._namespace_client = namespace_client
self.read_consistency_interval = read_consistency_interval
@@ -938,14 +904,6 @@ class AsyncLanceNamespaceDBConnection:
self._namespace_client_pushdown_operations = set(
namespace_client_pushdown_operations or []
)
self._namespace_client_impl = namespace_client_impl
self._namespace_client_properties = namespace_client_properties
# See LanceNamespaceDBConnection: when built natively the Rust table runs
# QueryTable pushdown through the read-freshness provider, so defer to it
# rather than the urllib3 client (which omits x-lancedb-min-timestamp).
self._route_pushdown_to_rust = _builds_namespace_natively(
namespace_client_impl, namespace_client_properties
)
self._inner = AsyncConnection(
_connect_namespace_client(
namespace_client,
@@ -959,8 +917,8 @@ class AsyncLanceNamespaceDBConnection:
namespace_client_pushdown_operations=(
list(self._namespace_client_pushdown_operations)
),
namespace_client_impl=namespace_client_impl,
namespace_client_properties=namespace_client_properties,
namespace_client_impl=None,
namespace_client_properties=None,
)
)
@@ -1030,7 +988,6 @@ class AsyncLanceNamespaceDBConnection:
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
)
async def open_table(
@@ -1068,7 +1025,6 @@ class AsyncLanceNamespaceDBConnection:
namespace_path=namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._namespace_client_pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
)
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
@@ -1427,6 +1383,4 @@ def connect_namespace_async(
storage_options=storage_options,
session=session,
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
namespace_client_impl=namespace_client_impl,
namespace_client_properties=namespace_client_properties,
)

View File

@@ -48,14 +48,6 @@ class PermutationBuilder:
By default, the permutation builder will create a single split that contains all
rows in the same order as the base table.
"""
if not hasattr(table, "_inner"):
raise TypeError(
f"PermutationBuilder requires a local LanceTable, "
f"got {type(table).__name__}. "
"The permutation API is not supported on remote tables. "
"Remote tables connect to LanceDB Cloud or Enterprise and do not have "
"direct access to the underlying Lance dataset needed for permutations."
)
self._async = async_permutation_builder(table)
def split_random(

View File

@@ -275,18 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
tz = get_extras(field, "tz")
return pa.timestamp("us", tz=tz)
elif getattr(py_type, "__origin__", None) in (list, tuple):
# A bare, unparameterised ``typing.List`` / ``typing.Tuple`` matches this
# branch (its ``__origin__`` is ``list`` / ``tuple``) but has no
# ``__args__``, so we cannot infer the element type. Raise a clear
# ``TypeError`` instead of crashing with an opaque ``AttributeError``.
args = getattr(py_type, "__args__", None)
if not args:
raise TypeError(
"Converting Pydantic type to Arrow Type: unsupported type "
f"{py_type}. Specify the element type, e.g. List[int] instead "
"of a bare List."
)
child = args[0]
child = py_type.__args__[0]
return _pydantic_list_child_to_arrow(child, field)
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."

View File

@@ -9,7 +9,6 @@ from typing import List, Optional
from lancedb import __version__
from .header import HeaderProvider
from .oauth import OAuthConfig, OAuthFlowType
__all__ = [
"TimeoutConfig",
@@ -17,8 +16,6 @@ __all__ = [
"TlsConfig",
"ClientConfig",
"HeaderProvider",
"OAuthConfig",
"OAuthFlowType",
]

View File

@@ -124,7 +124,6 @@ class RemoteDBConnection(DBConnection):
"request_thread_pool is no longer used and will be removed in "
"a future release.",
DeprecationWarning,
stacklevel=2,
)
if connection_timeout is not None:
@@ -133,7 +132,6 @@ class RemoteDBConnection(DBConnection):
"release. Please use client_config.timeout_config.connect_timeout "
"instead.",
DeprecationWarning,
stacklevel=2,
)
client_config.timeout_config.connect_timeout = timedelta(
seconds=connection_timeout
@@ -144,7 +142,6 @@ class RemoteDBConnection(DBConnection):
"read_timeout is deprecated and will be removed in a future release. "
"Please use client_config.timeout_config.read_timeout instead.",
DeprecationWarning,
stacklevel=2,
)
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)

View File

@@ -1,75 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
class OAuthFlowType(str, Enum):
"""OAuth authentication flow types."""
CLIENT_CREDENTIALS = "client_credentials"
"""Client Credentials grant (service-to-service / M2M)."""
AZURE_MANAGED_IDENTITY = "azure_managed_identity"
"""Azure Managed Identity via IMDS."""
@dataclass
class OAuthConfig:
"""OAuth configuration for LanceDB authentication.
All token acquisition and refresh is handled in the Rust layer.
This config is passed through to Rust via PyO3.
Parameters
----------
issuer_url : str
OIDC issuer URL or OAuth authority URL.
For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
client_id : str
Application / Client ID.
scopes : List[str]
OAuth scopes to request.
For Azure managed identity, exactly one scope or resource is required.
For example: ``["api://{app_id}/.default"]``
flow : OAuthFlowType
Authentication flow to use. Default: CLIENT_CREDENTIALS.
client_secret : Optional[str]
Client secret (required for CLIENT_CREDENTIALS).
managed_identity_client_id : Optional[str]
Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
refresh_buffer_secs : Optional[int]
Seconds before expiry to trigger proactive refresh (default: 300).
Keep this well below the token TTL; if it is greater than or equal to
the TTL, each request refreshes the token.
Examples
--------
Client Credentials (service-to-service):
>>> config = OAuthConfig(
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
... client_id="app-id",
... client_secret="secret",
... scopes=["api://lancedb-api/.default"],
... )
Azure Managed Identity:
>>> config = OAuthConfig(
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
... client_id="app-id",
... scopes=["api://lancedb-api/.default"],
... flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
... )
"""
issuer_url: str
client_id: str
scopes: List[str]
flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
client_secret: Optional[str] = field(default=None, repr=False)
managed_identity_client_id: Optional[str] = None
refresh_buffer_secs: Optional[int] = None

View File

@@ -13,14 +13,10 @@ from typing import (
Iterable,
List,
Optional,
TYPE_CHECKING,
Union,
Literal,
overload,
)
if TYPE_CHECKING:
from ..udf import JobHandle
import warnings
from lancedb import __version__
@@ -849,8 +845,7 @@ class RemoteTable(Table):
"""
warnings.warn(
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
"Tables are automatically cleaned up and optimized.",
stacklevel=2,
"Tables are automatically cleaned up and optimized."
)
pass
@@ -862,8 +857,7 @@ class RemoteTable(Table):
"""
warnings.warn(
"compact_files() is a no-op on LanceDB Cloud. "
"Tables are automatically compacted and optimized.",
stacklevel=2,
"Tables are automatically compacted and optimized."
)
pass
@@ -880,150 +874,15 @@ class RemoteTable(Table):
"""
warnings.warn(
"optimize() is a no-op on LanceDB Cloud. "
"Indices are optimized automatically.",
stacklevel=2,
"Indices are optimized automatically."
)
pass
def count_rows(self, filter: Optional[str] = None) -> int:
return LOOP.run(self._table.count_rows(filter))
def add_columns(
self,
transforms: Optional[Dict[str, str]] = None,
*,
computed: Optional[Dict[str, tuple]] = None,
) -> Optional[AddColumnsResult]:
result = None
if transforms is not None:
result = LOOP.run(self._table.add_columns(transforms))
if computed:
LOOP.run(self._table.add_columns(computed=computed))
return result
def refresh_column(
self,
columns,
*,
where: Optional[str] = None,
num_workers: Optional[int] = None,
max_workers: Optional[int] = None,
batch_size: Optional[int] = None,
priority: Optional[str] = None,
) -> "JobHandle":
"""Trigger recompute of computed columns (REFRESH COLUMN).
The expression is resolved server-side from each column's stored
binding; columns bound to the same struct-returning function
refresh together. Returns a `JobHandle` to wait on, poll, or cancel
(``tbl.refresh_column("c").wait()``). Server-backed feature
(LanceDB Enterprise / Cloud).
num_workers / max_workers / batch_size / priority are per-refresh
scheduling knobs (how to run THIS refresh) and override any default
the function carries. `priority` is a Kueue tier
(training | interactive | backfill).
"""
from ..udf import JobHandle
if isinstance(columns, str):
columns = [columns]
job_id = LOOP.run(
self._table.refresh_column(
list(columns),
where=where,
num_workers=num_workers,
max_workers=max_workers,
batch_size=batch_size,
priority=priority,
)
)
return JobHandle(self._job_conn(), job_id)
def lineage(self, column=None, *, direction=None, depth=None):
"""Derived-compute lineage of this table, or one of its columns:
upstream sources, downstream dependents, and the function version +
location that produced each derived column (with a drift flag). Returns
a `Lineage`. See `Connection.lineage`."""
return self._job_conn().lineage(
self._name, column, direction=direction, depth=depth
)
def _job_conn(self):
"""A client connection for polling jobs this table spawns. Built lazily
from the table's serialized connection state and cached (not pickled --
a forked/unpickled table rebuilds it on next use)."""
from lancedb import deserialize_conn
conn = getattr(self, "_job_conn_cache", None)
if conn is None:
conn = deserialize_conn(self._serialized_connection_state())
self._job_conn_cache = conn
return conn
def load_columns(
self,
source: Union[str, Iterable[str]],
pk: str,
columns: Union[Iterable[str], Dict[str, str]],
*,
source_format: str = "parquet",
source_pk: Optional[str] = None,
on_missing: str = "carry",
source_storage_options: Optional[Dict[str, str]] = None,
num_workers: Optional[int] = None,
max_workers: Optional[int] = None,
batch_size: Optional[int] = None,
commit_granularity: Optional[int] = None,
priority: Optional[str] = None,
) -> str:
"""Fill existing columns from an external source by primary-key join.
The distributed-job equivalent of Geneva's ``Table.load_columns()``:
imports precomputed values (e.g. embeddings) from Parquet/Lance/IPC into
this table, matching on a primary key. Returns the load job id.
Server-backed feature (LanceDB Enterprise / Cloud).
Parameters
----------
source: str | list[str]
One source URI or a list of URIs.
pk: str
Destination primary-key column. Also the source key unless
``source_pk`` is given.
columns: list[str] | dict[str, str]
Value columns to load. A list loads same-named columns; a dict maps
``{target: source}``.
source_format: str
``"parquet"`` (default), ``"lance"``, or ``"ipc"``.
source_pk: str, optional
Source primary-key column when it differs from ``pk``.
on_missing: str
Behavior for destination rows with no source match:
``"carry"`` (default, keep existing), ``"null"``, or ``"error"``.
"""
if isinstance(source, str):
source = [source]
if isinstance(columns, dict):
mappings = [(target, src) for target, src in columns.items()]
else:
mappings = [(c, None) for c in columns]
return LOOP.run(
self._table.load_columns(
list(source),
source_format,
pk,
mappings,
source_key=source_pk,
source_storage_options=source_storage_options,
on_missing=on_missing,
num_workers=num_workers,
max_workers=max_workers,
batch_size=batch_size,
commit_granularity=commit_granularity,
priority=priority,
)
)
def add_columns(self, transforms: Dict[str, str]) -> AddColumnsResult:
return LOOP.run(self._table.add_columns(transforms))
def alter_columns(
self, *alterations: Iterable[Dict[str, str]]

View File

@@ -86,10 +86,7 @@ def _from_list(data: list) -> Scannable:
@to_scannable.register(dict)
def _from_dict(data: dict) -> Scannable:
raise ValueError(
"Cannot create or add rows from a single dictionary. "
"Use a list of dictionaries instead."
)
raise ValueError("Cannot add a single dictionary to a table. Use a list.")
@to_scannable.register(LanceModel)

View File

@@ -243,10 +243,7 @@ def _into_pyarrow_reader(
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
if isinstance(data, dict):
raise ValueError(
"Cannot create or add rows from a single dictionary. "
"Use a list of dictionaries instead."
)
raise ValueError("Cannot add a single dictionary to a table. Use a list.")
if isinstance(data, list):
# Handle empty list case
@@ -702,24 +699,6 @@ def _normalize_progress(progress):
return progress, False
def _computed_groups(computed):
"""Group computed columns by expression, preserving declaration order
(struct-returning functions need their columns adjacent so schema order
matches field order). Accepts the ergonomic forms -- `fn("col")` values
and tuple keys for struct fan-out -- via `_normalize_computed`."""
from .udf import _normalize_computed
groups = []
for name, (sql_type, expression) in _normalize_computed(computed).items():
for expr, cols in groups:
if expr == expression:
cols.append((name, sql_type))
break
else:
groups.append((expression, [(name, sql_type)]))
return groups
class Table(ABC):
"""
A Table is a collection of Records in a LanceDB Database.
@@ -825,59 +804,6 @@ class Table(ABC):
"""The number of rows in this Table"""
return self.count_rows(None)
def add_computed_column(
self,
columns,
fn,
args: Optional[List[str]] = None,
types=None,
) -> None:
"""Declare computed column(s) bound to a UDF -- no compute happens
here (the agent fills them lazily, or refresh_column() triggers a run).
.. deprecated::
A computed column is an expression over a registered function, so
bind it as one: ``add_columns(computed={"vec": embed("data")})``.
``embed("data")`` applies the function to the `data` column and
infers the type from the function's return signature -- the
function never couples to a particular column. Prefer that form.
"""
import warnings
warnings.warn(
"add_computed_column is deprecated; use add_columns(computed="
'{"vec": embed("data")}).',
DeprecationWarning,
stacklevel=2,
)
from .udf import Udf, struct_field_types
multi = isinstance(columns, (tuple, list))
if isinstance(fn, Udf):
expr = fn.expression(*(args or []))
if types is None:
if multi:
if not fn.returns.upper().startswith("STRUCT"):
raise ValueError(
"several columns need a STRUCT-returning function"
)
types = struct_field_types(fn.returns)
else:
types = fn.returns
else:
if types is None:
raise ValueError("pass types= when fn is a name string")
expr = f"{fn}({', '.join(args or [])})"
if multi:
if len(types) != len(columns):
raise ValueError(
f"{len(columns)} columns but {len(types)} output types"
)
computed = {c: (t, expr) for c, t in zip(columns, types)}
else:
computed = {columns: (types, expr)}
self.add_columns(computed=computed)
@property
@abstractmethod
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
@@ -2093,7 +2019,6 @@ class LanceTable(Table):
namespace_client: Optional[Any] = None,
managed_versioning: Optional[bool] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
_async: AsyncTable = None,
):
if namespace_path is None:
@@ -2103,14 +2028,6 @@ class LanceTable(Table):
self._location = location # Store location for use in _dataset_path
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
# When the connection built the namespace client natively (e.g. an
# enterprise "rest" connection), the underlying Rust table already
# executes QueryTable pushdown itself -- and, unlike this Python urllib3
# path, it routes through the read-freshness context provider that emits
# the ``x-lancedb-min-timestamp`` header. So we must defer pushdown to
# Rust instead of calling the Python ``namespace_client.query_table``
# directly, or reads silently bypass read-freshness (stale results).
self._route_pushdown_to_rust = route_pushdown_to_rust
if _async is not None:
self._table = _async
else:
@@ -2321,7 +2238,6 @@ class LanceTable(Table):
namespace_path=self._namespace_path,
namespace_client=self._namespace_client,
pushdown_operations=self._pushdown_operations,
route_pushdown_to_rust=self._route_pushdown_to_rust,
location=self._location,
_async=async_table,
)
@@ -2472,11 +2388,8 @@ class LanceTable(Table):
Returns
-------
pa.Table"""
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
):
return self._execute_query(Query()).read_all()
@@ -3428,7 +3341,6 @@ class LanceTable(Table):
location: Optional[str] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
):
"""
Create a new table.
@@ -3491,24 +3403,21 @@ class LanceTable(Table):
self._location = location
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._route_pushdown_to_rust = route_pushdown_to_rust
if data_storage_version is not None:
warnings.warn(
"setting data_storage_version directly on create_table is deprecated. "
"setting data_storage_version directly on create_table is deprecated. ",
"Use database_options instead.",
DeprecationWarning,
stacklevel=2,
)
if storage_options is None:
storage_options = {}
storage_options["new_table_data_storage_version"] = data_storage_version
if enable_v2_manifest_paths is not None:
warnings.warn(
"setting enable_v2_manifest_paths directly on create_table is "
"setting enable_v2_manifest_paths directly on create_table is ",
"deprecated. Use database_options instead.",
DeprecationWarning,
stacklevel=2,
)
if storage_options is None:
storage_options = {}
@@ -3605,7 +3514,6 @@ class LanceTable(Table):
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
and self.current_branch() is None
):
from lancedb.namespace import _execute_server_side_query
@@ -3781,68 +3689,9 @@ class LanceTable(Table):
return LOOP.run(self._table.index_stats(index_name))
def add_columns(
self,
transforms: Dict[str, str]
| pa.field
| List[pa.field]
| pa.Schema
| None = None,
*,
computed: Optional[Dict] = None,
) -> Optional[AddColumnsResult]:
result = None
if transforms is not None:
result = LOOP.run(self._table.add_columns(transforms))
if computed:
# computed binds an expression over a registered function to a
# column: {col: fn("input_col")} -- fn("input_col") yields the
# expression and carries the inferred type; a tuple key fans a
# STRUCT return out to several columns. Declares the binding only;
# the server fills the values (server-backed). The legacy
# {col: (sql_type, expression)} tuple form is still accepted.
result_unused = LOOP.run(self._table.add_columns(computed=computed))
del result_unused
return result
def refresh_column(
self,
columns,
*,
where: Optional[str] = None,
num_workers: Optional[int] = None,
max_workers: Optional[int] = None,
batch_size: Optional[int] = None,
priority: Optional[str] = None,
) -> "JobHandle":
"""Trigger recompute of computed columns (REFRESH COLUMN).
The expression is resolved server-side from each column's stored
binding; columns bound to the same struct-returning function
refresh together. Returns a `JobHandle` to wait on, poll, or cancel
(``tbl.refresh_column("col").wait()``) -- mirrors
`MaterializedView.refresh()`. Server-backed feature (LanceDB
Enterprise / Cloud).
num_workers / max_workers / batch_size / priority are per-refresh
scheduling knobs (how to run THIS refresh) and override any default
the function carries. `priority` is a Kueue tier
(training | interactive | backfill).
"""
from .udf import JobHandle
if isinstance(columns, str):
columns = [columns]
job_id = LOOP.run(
self._table.refresh_column(
list(columns),
where=where,
num_workers=num_workers,
max_workers=max_workers,
batch_size=batch_size,
priority=priority,
)
)
return JobHandle(self._conn, job_id, table=self.name)
self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema
) -> AddColumnsResult:
return LOOP.run(self._table.add_columns(transforms))
def alter_columns(
self, *alterations: Iterable[Dict[str, str]]
@@ -4406,7 +4255,6 @@ class AsyncTable:
namespace_path: Optional[List[str]] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
):
"""Create a new AsyncTable object.
@@ -4419,9 +4267,6 @@ class AsyncTable:
self._namespace_path = namespace_path or []
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
# See LanceTable.__init__: defer QueryTable pushdown to Rust (which emits
# the read-freshness header) for natively-built namespace clients.
self._route_pushdown_to_rust = route_pushdown_to_rust
def _set_namespace_context(
self,
@@ -4429,12 +4274,10 @@ class AsyncTable:
namespace_path: Optional[List[str]] = None,
namespace_client: Optional[Any] = None,
pushdown_operations: Optional[set] = None,
route_pushdown_to_rust: bool = False,
) -> "AsyncTable":
self._namespace_path = namespace_path or []
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._route_pushdown_to_rust = route_pushdown_to_rust
return self
def __repr__(self):
@@ -4644,11 +4487,8 @@ class AsyncTable:
-------
pa.Table
"""
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
):
return (await self._execute_query(Query())).read_all()
@@ -5332,11 +5172,8 @@ class AsyncTable:
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
if (
_should_push_down_query_table(
self._namespace_client, self._pushdown_operations
)
and not self._route_pushdown_to_rust
if _should_push_down_query_table(
self._namespace_client, self._pushdown_operations
):
from lancedb.namespace import _execute_server_side_query
@@ -5520,44 +5357,9 @@ class AsyncTable:
return await self._inner.update(updates_sql, where)
async def refresh_column(
self,
columns,
*,
where: Optional[str] = None,
num_workers: Optional[int] = None,
max_workers: Optional[int] = None,
batch_size: Optional[int] = None,
priority: Optional[str] = None,
) -> str:
"""Trigger recompute of computed columns (REFRESH COLUMN).
Returns the refresh job id. Server-backed feature.
num_workers / max_workers / batch_size / priority are per-refresh
scheduling knobs (how to run THIS refresh); they override any default
the function carries. `priority` is a Kueue tier
(training | interactive | backfill)."""
if isinstance(columns, str):
columns = [columns]
return await self._inner.refresh_column(
list(columns),
where_clause=where,
num_workers=num_workers,
max_workers=max_workers,
batch_size=batch_size,
priority=priority,
)
async def add_columns(
self,
transforms: dict[str, str]
| pa.field
| List[pa.field]
| pa.Schema
| None = None,
*,
computed: Optional[Dict] = None,
) -> Optional[AddColumnsResult]:
self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema
) -> AddColumnsResult:
"""
Add new columns with defined values.
@@ -5576,7 +5378,6 @@ class AsyncTable:
version: the new version number of the table after adding columns.
"""
result = None
if isinstance(transforms, pa.Field):
transforms = [transforms]
if isinstance(transforms, list) and all(
@@ -5584,69 +5385,9 @@ class AsyncTable:
):
transforms = pa.schema(transforms)
if isinstance(transforms, pa.Schema):
result = await self._inner.add_columns_with_schema(transforms)
elif transforms is not None:
result = await self._inner.add_columns(list(transforms.items()))
if computed:
# computed binds an expression over a registered function to a
# column: {col: fn("input_col")} -- fn("input_col") yields the
# expression and carries the inferred type; a tuple key fans a
# STRUCT return out to several columns. Declares the binding only;
# the server fills the values (server-backed). The legacy
# {col: (sql_type, expression)} tuple form is still accepted.
for expression, cols in _computed_groups(computed):
await self._inner.add_computed_columns(cols, expression)
return result
async def add_computed_column(
self,
columns,
fn,
args: Optional[List[str]] = None,
types=None,
) -> None:
"""Declare computed column(s) bound to a UDF (async).
.. deprecated::
Use ``add_columns(computed={"col": fn("input_col")})`` -- a computed
column is an expression over a registered function, so bind it that
way instead of coupling the UDF to the column here.
"""
import warnings
warnings.warn(
"add_computed_column is deprecated; use add_columns(computed="
'{"col": fn("input_col")}).',
DeprecationWarning,
stacklevel=2,
)
from .udf import Udf, struct_field_types
multi = isinstance(columns, (tuple, list))
if isinstance(fn, Udf):
expr = fn.expression(*(args or []))
if types is None:
if multi:
if not fn.returns.upper().startswith("STRUCT"):
raise ValueError(
"several columns need a STRUCT-returning function"
)
types = struct_field_types(fn.returns)
else:
types = fn.returns
return await self._inner.add_columns_with_schema(transforms)
else:
if types is None:
raise ValueError("pass types= when fn is a name string")
expr = f"{fn}({', '.join(args or [])})"
if multi:
if len(types) != len(columns):
raise ValueError(
f"{len(columns)} columns but {len(types)} output types"
)
computed = {c: (t, expr) for c, t in zip(columns, types)}
else:
computed = {columns: (types, expr)}
await self.add_columns(computed=computed)
return await self._inner.add_columns(list(transforms.items()))
async def alter_columns(
self, *alterations: Iterable[dict[str, Any]]
@@ -5918,7 +5659,6 @@ class AsyncTable:
"The 'retrain' parameter is deprecated and will be removed in a "
"future version.",
DeprecationWarning,
stacklevel=2,
)
return await self._inner.optimize(

View File

@@ -1,753 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""UDF authoring for LanceDB derived compute (server-backed).
`@udf` / `@table_udf` turn a plain Python function into a registrable
server-side UDF: a cloudpickled (or source) body, a SQL signature inferred
from type hints, and the runtime options (pip deps, GPUs, batching, ...).
Register and use them through the existing connection/table API:
import lancedb
from lancedb import udf, table_udf
db = lancedb.connect("db://my_db", api_key="...", host_override="...")
@udf(pip=["torch>=2.0"], num_gpus=1)
def embed(text: str) -> list[float]:
return model.encode(text).tolist()
db.create_function(embed) # CREATE FUNCTION (once)
tbl = db.open_table("docs")
tbl.add_columns(computed={"vec": embed("text")}) # bind embed(text) -> vec
tbl.refresh_column("vec").wait() # materialize (returns a JobHandle)
view = db.create_materialized_view("chunks", tbl, ["id", chunk_fn])
`embed("text")` applies the registered function to the `text` column and yields
the expression `embed(text)`; the function itself stays decoupled from any
column, so the same `embed` works on any column or table.
These operations are server-backed (LanceDB Enterprise / Cloud); the
decorator itself works locally (define + call), only registration needs a
remote connection.
"""
from __future__ import annotations
import asyncio
import base64
import dataclasses
import functools
import inspect
import re
import sys
import textwrap
import time
import typing
# -- type hints -> SQL type strings -------------------------------------
_SCALARS = {
int: "BIGINT",
# Pragmatic default for ML workloads: python float maps to FLOAT
# (Float32). Use an explicit `returns=` for DOUBLE.
float: "FLOAT",
str: "VARCHAR",
bool: "BOOLEAN",
bytes: "BLOB",
}
class TypeInferenceError(TypeError):
pass
def sql_type(hint) -> str:
"""SQL type string for a python type hint."""
if hint in _SCALARS:
return _SCALARS[hint]
origin = typing.get_origin(hint)
if origin in (list, typing.List):
(item,) = typing.get_args(hint) or (None,)
if item in _SCALARS:
return f"{_SCALARS[item]}[]"
raise TypeInferenceError(
f"unsupported list item type {item!r}; use an explicit returns="
)
fields = _struct_fields(hint)
if fields is not None:
inner = ", ".join(f"{name} {sql_type(h)}" for name, h in fields)
return f"STRUCT({inner})"
raise TypeInferenceError(
f"cannot infer a SQL type for {hint!r}; pass an explicit type string"
)
def _struct_fields(hint):
"""(name, hint) pairs for a TypedDict or dataclass, else None."""
if dataclasses.is_dataclass(hint):
return [(f.name, f.type) for f in dataclasses.fields(hint)]
# TypedDict detection: a dict subclass with __annotations__.
if (
isinstance(hint, type)
and issubclass(hint, dict)
and typing.get_type_hints(hint)
):
return list(typing.get_type_hints(hint).items())
return None
def return_type(fn, override: "str | None", table: bool) -> str:
"""SQL return type for a function: explicit override wins, else the
return annotation. Table functions render as TABLE(...) and accept
struct-shaped hints (TypedDict/dataclass, optionally list-wrapped)."""
if override is not None:
s = override.strip()
if table and not s.upper().startswith("TABLE"):
if s.upper().startswith("STRUCT"):
return "TABLE" + s[len("STRUCT") :]
raise TypeInferenceError(
"a table function's returns= must be TABLE(...) or STRUCT(...)"
)
return s
hints = typing.get_type_hints(fn)
ret = hints.get("return")
if ret is None:
raise TypeInferenceError(
f"function {fn.__name__!r} needs a return annotation or returns="
)
if table:
# Accept list[Row] / Row where Row is a TypedDict or dataclass.
if typing.get_origin(ret) in (list, typing.List):
(ret,) = typing.get_args(ret)
fields = _struct_fields(ret)
if fields is None:
raise TypeInferenceError(
"a table function must return rows shaped as a TypedDict or "
"dataclass (optionally list-wrapped); or pass returns=..."
)
inner = ", ".join(f"{name} {sql_type(h)}" for name, h in fields)
return f"TABLE({inner})"
return sql_type(ret)
def param_types(fn) -> "list[tuple[str, str]]":
"""(name, sql type) per parameter, from annotations. Each UDF
parameter binds to a source column of the same name by default."""
hints = typing.get_type_hints(fn)
out = []
for name, p in inspect.signature(fn).parameters.items():
if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
raise TypeInferenceError("*args/**kwargs are not supported in UDFs")
hint = hints.get(name)
if hint is None:
raise TypeInferenceError(
f"parameter {name!r} of {fn.__name__!r} needs a type annotation"
)
out.append((name, sql_type(hint)))
return out
# -- column expressions -------------------------------------------------
class ColumnExpr(str):
"""A computed-column expression produced by applying a registered
function to column names, e.g. ``embed("data") -> "embed(data)"``.
It IS the expression string everywhere a string is expected (views, SQL,
logging), and additionally carries the function's declared return type so
``add_columns(computed=...)`` can declare the column without a hand-written
type. ``field_types`` holds the per-field SQL types of a STRUCT return, for
fanning one expression out to several columns.
"""
data_type: "str | None"
field_types: "list[str] | None"
def __new__(cls, expr: str, data_type=None, field_types=None):
obj = super().__new__(cls, expr)
obj.data_type = data_type
obj.field_types = field_types
return obj
def _normalize_computed(computed: dict) -> dict:
"""Normalize the user-facing ``computed=`` mapping to the canonical
``{name: (sql_type, expression)}`` form.
Accepts, per entry:
- value is a `ColumnExpr` (from ``fn("col")``): the column's SQL type
comes from the function's return type -- no hand-written type needed. A
tuple key (``("chunk", "idx")``) fans a STRUCT return out to one
(type, expression) entry per field, in declared order.
- value is a legacy ``(sql_type, expression)`` tuple: passed through (the
escape hatch, e.g. bare-name function strings).
"""
out: dict = {}
for key, val in computed.items():
if isinstance(val, ColumnExpr):
expr = str(val)
if isinstance(key, (tuple, list)):
if not val.field_types:
raise ValueError(
f"columns {tuple(key)} need a STRUCT-returning function; "
f"{expr} returns a single value"
)
if len(val.field_types) != len(key):
raise ValueError(
f"{len(key)} columns but {len(val.field_types)} struct fields "
f"in {expr}"
)
for name, t in zip(key, val.field_types):
out[name] = (t, expr)
else:
if val.data_type is None:
raise ValueError(f"cannot infer a type for {expr}; pass types=")
out[key] = (val.data_type, expr)
else:
out[key] = val
return out
# -- the @udf / @table_udf decorators -----------------------------------
class Udf:
def __init__(
self,
fn,
*,
returns: "str | None" = None,
table: bool = False,
name: "str | None" = None,
pip: "list[str] | None" = None,
pip_index_url: "str | None" = None,
pip_extra_index_urls: "list[str] | None" = None,
find_links: "list[str] | None" = None,
requirements: "str | list[str] | None" = None,
conda: "list[str] | None" = None,
conda_channels: "list[str] | None" = None,
env: "dict[str, str] | list[str] | None" = None,
num_cpus: "int | None" = None,
num_gpus: "int | None" = None,
batch_size: "int | None" = None,
timeout: "float | None" = None,
error_policy: "str | None" = None,
max_skip_ratio: "float | None" = None,
retries: "int | None" = None,
docker_image: "str | None" = None,
description: "str | None" = None,
prefer_source: bool = False,
):
functools.update_wrapper(self, fn)
self.fn = fn
self.name = name or fn.__name__
self.table = table
self.params = param_types(fn)
self.returns = return_type(fn, returns, table)
self.prefer_source = prefer_source
self.options: "dict[str, str]" = {}
if conda and (pip or requirements):
raise ValueError("pass conda or pip/requirements, not both")
if conda_channels and not conda:
raise ValueError("conda_channels requires conda")
if pip:
self.options["pip"] = ",".join(pip)
if pip_extra_index_urls:
self.options["pip_extra_index_urls"] = ",".join(pip_extra_index_urls)
if find_links:
self.options["find_links"] = ",".join(find_links)
if requirements:
self.options["requirements"] = _format_requirements(requirements)
if conda:
self.options["conda"] = ",".join(conda)
if conda_channels:
self.options["conda_channels"] = ",".join(conda_channels)
if env:
self.options["env"] = _format_env(env)
for key, val in [
("pip_index_url", pip_index_url),
("num_cpus", num_cpus),
("num_gpus", num_gpus),
("batch_size", batch_size),
("timeout", timeout),
("error_policy", error_policy),
("max_skip_ratio", max_skip_ratio),
("retries", retries),
("docker_image", docker_image),
]:
if val is not None:
self.options[key] = str(val)
# Keep the source in the description (when available) so the
# catalog stays inspectable even for pickled bodies.
if description is not None:
self.options["description"] = description
else:
try:
self.options["description"] = textwrap.dedent(inspect.getsource(fn))
except (OSError, TypeError):
pass
def __call__(self, *args, **kwargs):
"""Call with real values to run locally; call with column-name
strings to build an expression for backfills and views, e.g.
``embed("data")`` -> the expression ``embed(data)`` (a `ColumnExpr`
carrying the function's return type for `add_columns(computed=...)`)."""
if args and all(isinstance(a, str) for a in args) and not kwargs:
return self.expression(*args)
return self.fn(*args, **kwargs)
def expression(self, *columns: str) -> ColumnExpr:
"""The expression applying this function to `columns` (default: the
function's own parameter names). Returns a `ColumnExpr` -- a string
that also carries the declared return type (and struct field types)."""
cols = columns or [p for p, _ in self.params]
expr = f"{self.name}({', '.join(cols)})"
field_types = None
if self.returns.upper().startswith("STRUCT"):
field_types = struct_field_types(self.returns)
return ColumnExpr(expr, data_type=self.returns, field_types=field_types)
def _body(self) -> "tuple[str, str]":
"""(body literal, body_format). Source when requested and
retrievable; cloudpickle otherwise (handles closures)."""
if self.prefer_source:
try:
src = textwrap.dedent(inspect.getsource(self.fn))
# Strip the decorator line(s) so the stored body is a
# plain function definition.
lines = src.splitlines(keepends=True)
while lines and lines[0].lstrip().startswith("@"):
lines.pop(0)
return "".join(lines), "source"
except (OSError, TypeError):
pass
import cloudpickle
raw = cloudpickle.dumps(self.fn)
return base64.b64encode(raw).decode("ascii"), "cloudpickle"
def _body_and_options(self) -> "tuple[str, dict[str, str]]":
"""The body literal plus the finalized options (body_format /
python_version / cloudpickle-pip bookkeeping for a non-source
body)."""
body, body_format = self._body()
options = dict(self.options)
if body_format != "source":
options["body_format"] = body_format
# Pickled code objects only load under the same interpreter
# minor version; record ours so the worker can fail with a
# clear message instead of a bytecode error.
options["python_version"] = self.pickle_environment()
# The worker deserializes the body with cloudpickle; make sure
# the job's pip environment provides it. Conda bakes inject
# cloudpickle server-side, so do not create an invalid pip+conda
# declaration here.
if "conda" not in options:
pip = [d for d in options.get("pip", "").split(",") if d]
if not any(d.startswith("cloudpickle") for d in pip):
pip.append("cloudpickle")
options["pip"] = ",".join(pip)
return body, options
def create_request(self) -> dict:
"""Keyword arguments for `connection.create_function`."""
body, options = self._body_and_options()
return {
"name": self.name,
"language": "python",
"return_type": self.returns,
"body": body,
"options": options,
}
def create_statement(self) -> str:
"""The equivalent `CREATE FUNCTION` SQL (for SQL-surface callers)."""
params = ", ".join(f"{n} {t}" for n, t in self.params)
body, options = self._body_and_options()
with_clause = ""
if options:
rendered = ", ".join(
f"{k} = '{_escape(v)}'" for k, v in sorted(options.items())
)
with_clause = f" WITH ({rendered})"
return (
f"CREATE FUNCTION {self.name}({params}) RETURNS {self.returns} "
f"LANGUAGE python AS '{_escape_body(body)}'{with_clause}"
)
def pickle_environment(self) -> str:
"""Python version the body pickles under -- workers should match
the minor version for cloudpickle compatibility."""
return f"{sys.version_info.major}.{sys.version_info.minor}"
def _escape(s: str) -> str:
return str(s).replace("'", "''")
def _format_requirements(requirements: "str | list[str]") -> str:
if isinstance(requirements, str):
return requirements
return "\n".join(str(req) for req in requirements)
def _format_env(env: "dict[str, str] | list[str]") -> str:
if isinstance(env, dict):
return "; ".join(f"{key}={value}" for key, value in env.items())
return "; ".join(str(entry) for entry in env)
def _escape_body(body: str) -> str:
# The server unescapes \n / \t in single-quoted bodies; encode real
# newlines accordingly and escape quotes.
return (
body.replace("\\", "\\\\")
.replace("'", "''")
.replace("\n", "\\n")
.replace("\t", "\\t")
)
def udf(fn=None, **kwargs):
"""Decorate a function as a scalar (or struct-returning) UDF.
@udf
def doubled(val: int) -> float: ...
@udf(pip=["torch>=2"], num_gpus=1)
def embed(body: str) -> list[float]: ...
"""
if fn is not None:
return Udf(fn, **kwargs)
return lambda f: Udf(f, **kwargs)
def table_udf(fn=None, **kwargs):
"""Decorate a table function (UDTF): each input row may emit zero or
more output rows. Only usable in materialized views.
class Chunk(TypedDict):
chunk: str
chunk_idx: int
@table_udf
def chunker(body: str) -> list[Chunk]: ...
"""
kwargs["table"] = True
if fn is not None:
return Udf(fn, **kwargs)
return lambda f: Udf(f, **kwargs)
# -- view / job handles (thin references over a connection) -------------
def struct_field_types(returns: str) -> "list[str]":
"""Field type strings of a STRUCT(...) SQL type, in declared order."""
inner = returns.strip()[len("STRUCT(") : -1]
fields, depth, start = [], 0, 0
for i, c in enumerate(inner):
if c in "([":
depth += 1
elif c in ")]":
depth -= 1
elif c == "," and depth == 0:
fields.append(inner[start:i].strip())
start = i + 1
fields.append(inner[start:].strip())
# Each field is "name TYPE"; drop the name.
return [f.split(None, 1)[1] for f in fields]
def build_view_query(source, select) -> str:
"""Assemble a view SELECT from a source (name or table) and select
items: a column name, an expression string, a (alias, expression)
tuple, or a @udf/@table_udf object."""
src = source.name if hasattr(source, "name") else source
items = []
for item in select:
if isinstance(item, Udf):
items.append(item.expression())
elif isinstance(item, tuple):
alias, expr = item
expr = expr.expression() if isinstance(expr, Udf) else expr
items.append(f"{expr} AS {alias}")
else:
items.append(item)
return f"SELECT {', '.join(items)} FROM {src}"
def _job_id_matches(handle_id: str, listed_id: str) -> bool:
# The refresh/backfill endpoints return the submission id (a uuid), but
# the agent names the manifest job "<table>-<type>-<first 8 of the
# submission id>" -- which is what list_jobs and cancel report. Match the
# canonical id directly, or by that submission prefix.
if listed_id == handle_id:
return True
prefix = handle_id[:8]
return len(prefix) >= 4 and prefix in listed_id
class MaterializedView:
"""A reference to a materialized view (name + connection). Operations are
server-backed connection calls bound to the name.
``create_materialized_view`` returns one of these; ``job_id`` is the
initial-population job (None when the view was created with no data), so
``db.create_materialized_view(...).wait()`` blocks until it is populated.
"""
def __init__(self, conn, name: str, job_id: "str | None" = None):
self.conn = conn
self.name = name
#: initial-population job id from create, or None (with_no_data).
self.job_id = job_id
def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
"""Block until the initial-population job (from create) finishes.
A no-op when the view was created with no data."""
if self.job_id is None:
return "finished"
return JobHandle(self.conn, self.job_id, table=self.name).wait(
timeout=timeout, poll=poll
)
def refresh(self, full: bool = False) -> "JobHandle":
"""Refresh the materialized view; returns a `JobHandle` to wait on,
poll, or cancel (``view.refresh().wait()``).
``full=True`` forces a full rebuild (recompute and replace every row)
instead of the default incremental refresh. A full rebuild preserves
the view's indexes -- they are reindexed by the distributed indexer.
"""
job_id = self.conn._refresh_materialized_view(self.name, full=full)
return JobHandle(self.conn, job_id, table=self.name)
def explain_refresh(self, full: bool = False):
"""Plan a refresh without running it (EXPLAIN REFRESH)."""
return self.conn.explain_refresh_materialized_view(self.name, full=full)
def alter(self, auto_refresh: bool) -> None:
self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)
def drop(self) -> None:
self.conn.drop_materialized_view(self.name)
# A materialized view is a first-class table: it can be indexed and
# searched like any other. These open the materialized dataset by name and
# delegate. Indexes declared this way are recorded against the view, so the
# engine re-applies them after a full refresh rebuilds the dataset (a full
# refresh overwrites the dataset, which would otherwise drop its indices).
def _table(self):
return self.conn.open_table(self.name)
def create_index(self, *args, **kwargs):
"""Build an index on the materialized view (see Table.create_index)."""
return self._table().create_index(*args, **kwargs)
def create_scalar_index(self, *args, **kwargs):
"""Build a scalar index on the materialized view."""
return self._table().create_scalar_index(*args, **kwargs)
def create_fts_index(self, *args, **kwargs):
"""Build a full-text-search index on the materialized view."""
return self._table().create_fts_index(*args, **kwargs)
def search(self, *args, **kwargs):
"""Search the materialized view (vector / FTS / hybrid)."""
return self._table().search(*args, **kwargs)
def lineage(self, column=None, *, direction=None, depth=None):
"""Lineage of the materialized view (or one of its columns). Delegates
to the backing table; the server already includes the view's sources
and downstream dependents. Returns a `Lineage`."""
return self._table().lineage(column, direction=direction, depth=depth)
_PROGRESS = re.compile(r"(\d+)/(\d+)")
class JobFailedError(RuntimeError):
"""Raised by ``JobHandle.wait()`` when the server reports the job ``failed``.
Carries the server-side error so a doomed backfill (e.g. a multi-column
``REFRESH COLUMN`` of a scalar UDF) surfaces its real cause promptly,
instead of the caller blocking until ``wait()``'s timeout.
"""
def __init__(self, job_id: str, error: "str | None"):
self.job_id = job_id
self.error = error
super().__init__(f"job {job_id} failed: {error or 'unknown error'}")
class JobHandle:
"""A reference to an inflight server-side job, with polling helpers."""
#: How long an unseen job is treated as still materializing (submission
#: -> agent cycle -> manifest write is async).
GRACE_SECONDS = 20.0
def __init__(self, conn, job_id: str, table: "str | None" = None):
self.conn = conn
self.id = job_id
#: The job's table, when known (refresh_column / MV refresh). Lets the
#: server resolve this job with an O(1) single-node read; without it the
#: lookup scans the database's active jobs (still correct).
self.table = table
self._created = time.monotonic()
self._seen = False
def _job(self):
# Poll by id (one job), not list_jobs (every active job): the server
# matches the submission/manifest id and reads just this table's node.
return self.conn.get_job(self.id, self.table)
def status(self) -> str:
"""pending / running / cancelling / stale, or 'finished' once the
job has left the inflight listing."""
job = self._job()
if job is not None:
self._seen = True
return job.state
if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
return "pending"
return "finished"
def progress(self) -> "tuple[int, int] | None":
"""(units_done, units_total) while running, else None."""
job = self._job()
if job is not None and job.units_total is not None:
return job.units_done or 0, job.units_total
return None
def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
state = self.status()
if state in ("finished", "stale"):
return state
if state == "failed":
# Terminal failure -- surface the server error now, don't block
# until `timeout`. `finalize` wrote it to the job's status node.
job = self._job()
raise JobFailedError(self.id, job.error if job is not None else None)
if state == "pending":
time.sleep(min(poll, 0.5))
continue
job = self._job()
if job is not None and job.committed:
return "finished"
time.sleep(poll)
raise TimeoutError(f"job {self.id} still {self.status()} after {timeout}s")
def cancel(self) -> None:
# Cancel by the canonical manifest id (what cancel matches), found
# via the submission prefix; fall back to the raw id.
job = self._job()
self.conn.cancel_job(job.job_id if job is not None else self.id)
class AsyncMaterializedView:
"""Async reference to a materialized view (name + async connection)."""
def __init__(self, conn, name: str, job_id: "str | None" = None):
self.conn = conn
self.name = name
#: initial-population job id from create, or None (with_no_data).
self.job_id = job_id
async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
"""Block until the initial-population job (from create) finishes.
A no-op when the view was created with no data."""
if self.job_id is None:
return "finished"
return await AsyncJobHandle(self.conn, self.job_id, table=self.name).wait(
timeout=timeout, poll=poll
)
async def refresh(self, full: bool = False) -> "AsyncJobHandle":
"""Refresh the materialized view; returns an `AsyncJobHandle` to wait
on, poll, or cancel.
``full=True`` forces a full rebuild instead of an incremental refresh
(indexes are preserved and reindexed by the distributed indexer).
"""
job_id = await self.conn._refresh_materialized_view(self.name, full=full)
return AsyncJobHandle(self.conn, job_id, table=self.name)
async def explain_refresh(self, full: bool = False):
return await self.conn.explain_refresh_materialized_view(self.name, full=full)
async def alter(self, auto_refresh: bool) -> None:
await self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)
async def drop(self) -> None:
await self.conn.drop_materialized_view(self.name)
async def lineage(self, column=None, *, direction=None, depth=None):
"""Lineage of the materialized view (or column). Returns a `Lineage`."""
return await self.conn.lineage(
self.name, column, direction=direction, depth=depth
)
class AsyncJobHandle:
"""Async reference to an inflight server-side job, with polling helpers."""
GRACE_SECONDS = 20.0
def __init__(self, conn, job_id: str, table: "str | None" = None):
self.conn = conn
self.id = job_id
#: See JobHandle.table -- enables an O(1) by-id lookup when known.
self.table = table
self._created = time.monotonic()
self._seen = False
async def _job(self):
# Poll by id, not list_jobs (see JobHandle._job).
return await self.conn.get_job(self.id, self.table)
async def status(self) -> str:
job = await self._job()
if job is not None:
self._seen = True
return job.state
if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
return "pending"
return "finished"
async def progress(self) -> "tuple[int, int] | None":
job = await self._job()
if job is not None and job.units_total is not None:
return job.units_done or 0, job.units_total
return None
async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
state = await self.status()
if state in ("finished", "stale"):
return state
if state == "failed":
# Terminal failure -- surface the server error now, don't block
# until `timeout`. `finalize` wrote it to the job's status node.
job = await self._job()
raise JobFailedError(self.id, job.error if job is not None else None)
if state == "pending":
await asyncio.sleep(min(poll, 0.5))
continue
job = await self._job()
if job is not None and job.committed:
return "finished"
await asyncio.sleep(poll)
raise TimeoutError(
f"job {self.id} still {await self.status()} after {timeout}s"
)
async def cancel(self) -> None:
job = await self._job()
await self.conn.cancel_job(job.job_id if job is not None else self.id)

View File

@@ -373,15 +373,9 @@ def _(value: list):
@value_to_sql.register(dict)
def _(value: dict):
# https://datafusion.apache.org/user-guide/sql/scalar_functions.html#named-struct
# Render the field name through value_to_sql(str(...)) as well so that keys
# containing characters meaningful in SQL (e.g. a single quote) are escaped
# the same way string values are. A bare f"'{k}'" would emit invalid SQL for
# a key like "it's".
return (
"named_struct("
+ ", ".join(
f"{value_to_sql(str(k))}, {value_to_sql(v)}" for k, v in value.items()
)
+ ", ".join(f"'{k}', {value_to_sql(v)}" for k, v in value.items())
+ ")"
)

View File

@@ -91,9 +91,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
# Can recreate if replace=True
await some_table.create_index("id", replace=True)
indices = await some_table.list_indices()
assert str(indices).startswith(
'[IndexConfig(name="id_idx", index_type="BTree", columns=["id"]'
)
assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
assert len(indices) == 1
assert indices[0].index_type == "BTree"
assert indices[0].columns == ["id"]
@@ -108,27 +106,6 @@ async def test_create_scalar_index(some_table: AsyncTable):
assert len(indices) == 0
@pytest.mark.asyncio
async def test_index_config_repr(db_async):
# Use >= 1000 rows so the thousands separator in the repr is exercised.
nrows = 1500
table = await db_async.create_table(
"repr_table", pa.Table.from_pydict({"id": list(range(nrows))})
)
await table.create_index("id", config=BTree())
indices = await table.list_indices()
assert len(indices) == 1
r = repr(indices[0])
assert r.startswith('IndexConfig(name="id_idx", index_type="BTree", columns=["id"]')
# Integer counts use `_` thousands separators (valid Python int syntax).
assert "num_indexed_rows=1_500" in r
assert "num_unindexed_rows=0" in r
# created_at renders as a datetime so the value round-trips.
assert "created_at=datetime.datetime(" in r
assert r.endswith(")")
@pytest.mark.asyncio
async def test_create_nested_scalar_index_lists_canonical_paths(db_async):
metadata_type = pa.struct(
@@ -221,9 +198,7 @@ async def test_create_nested_scalar_index_lists_canonical_paths(db_async):
async def test_create_fixed_size_binary_index(some_table: AsyncTable):
await some_table.create_index("fsb", config=BTree())
indices = await some_table.list_indices()
assert str(indices).startswith(
'[IndexConfig(name="fsb_idx", index_type="BTree", columns=["fsb"]'
)
assert str(indices) == '[Index(BTree, columns=["fsb"], name="fsb_idx")]'
assert len(indices) == 1
assert indices[0].index_type == "BTree"
assert indices[0].columns == ["fsb"]
@@ -272,9 +247,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
async def test_create_label_list_index(some_table: AsyncTable):
await some_table.create_index("tags", config=LabelList())
indices = await some_table.list_indices()
assert str(indices).startswith(
'[IndexConfig(name="tags_idx", index_type="LabelList", columns=["tags"]'
)
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
plan = await some_table.query().where("array_has(tags, 'tag0')").explain_plan()
assert "ScalarIndexQuery" in plan
@@ -289,9 +262,7 @@ async def test_create_large_list_label_list_index(db_async):
await table.create_index("tags", config=LabelList())
indices = await table.list_indices()
assert str(indices).startswith(
'[IndexConfig(name="tags_idx", index_type="LabelList", columns=["tags"]'
)
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
plan = await table.query().where("array_has(tags, 'shared')").explain_plan()
assert "ScalarIndexQuery" in plan
@@ -328,9 +299,7 @@ async def test_create_label_list_index_rejects_list_struct(db_async):
async def test_full_text_search_index(some_table: AsyncTable):
await some_table.create_index("tags", config=FTS(with_position=False))
indices = await some_table.list_indices()
assert str(indices).startswith(
'[IndexConfig(name="tags_idx", index_type="FTS", columns=["tags"]'
)
assert str(indices) == '[Index(FTS, columns=["tags"], name="tags_idx")]'
await some_table.prewarm_index("tags_idx")

View File

@@ -1,92 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""JobHandle.wait() terminal-state handling.
Regression coverage for the cluster backfill-failure hang: the server reports a
doomed job as ``state="failed"`` within seconds, but ``wait()`` used to ignore
``failed`` and block until its (default 3600s) timeout. These tests pin that a
``failed`` job raises ``JobFailedError`` promptly, carrying the server error.
"""
import asyncio
import time
import pytest
from lancedb.udf import JobHandle, AsyncJobHandle, JobFailedError
class FakeJobInfo:
"""Mirror of the pyo3 builtins.JobInfo fields wait()/status() read."""
def __init__(self, state, error=None, committed=False, units_total=None):
self.state = state
self.error = error
self.committed = committed
self.units_total = units_total
self.units_done = None
self.job_id = "job-1"
class FakeConn:
"""get_job() walks a scripted list of JobInfo (or None) snapshots, holding
the last one once exhausted, so wait() polls a deterministic timeline."""
def __init__(self, snapshots):
self._snaps = list(snapshots)
self.calls = 0
def get_job(self, job_id, table=None):
snap = self._snaps[min(self.calls, len(self._snaps) - 1)]
self.calls += 1
return snap
class AsyncFakeConn(FakeConn):
async def get_job(self, job_id, table=None):
return FakeConn.get_job(self, job_id, table)
def test_wait_raises_on_failed_promptly():
# pending -> failed: wait() must raise the server error, not TimeoutError.
conn = FakeConn(
[None, FakeJobInfo("failed", error="multi-column backfill needs a STRUCT")]
)
jh = JobHandle(conn, "job-1", table="t")
t0 = time.monotonic()
with pytest.raises(JobFailedError) as exc:
jh.wait(timeout=30, poll=0.01)
assert time.monotonic() - t0 < 5 # prompt, nowhere near the 30s timeout
assert "STRUCT" in str(exc.value)
assert exc.value.error == "multi-column backfill needs a STRUCT"
assert exc.value.job_id == "job-1"
def test_wait_returns_finished_on_success():
# running -> finished (job left the inflight listing) returns normally.
conn = FakeConn([FakeJobInfo("running", units_total=2), None])
jh = JobHandle(conn, "job-1", table="t")
jh._seen = True # already observed, so a None now means "finished" not grace
assert jh.wait(timeout=30, poll=0.01) == "finished"
def test_wait_returns_finished_on_committed():
# A committed job that is still listed resolves to finished.
conn = FakeConn([FakeJobInfo("running", committed=True, units_total=2)])
jh = JobHandle(conn, "job-1", table="t")
jh._seen = True
assert jh.wait(timeout=30, poll=0.01) == "finished"
def test_async_wait_raises_on_failed_promptly():
conn = AsyncFakeConn([None, FakeJobInfo("failed", error="boom")])
jh = AsyncJobHandle(conn, "job-1", table="t")
async def run():
t0 = time.monotonic()
with pytest.raises(JobFailedError) as exc:
await jh.wait(timeout=30, poll=0.01)
assert time.monotonic() - t0 < 5
assert exc.value.error == "boom"
asyncio.run(run())

View File

@@ -5,11 +5,11 @@
import tempfile
import shutil
import sys
import pytest
import pyarrow as pa
import lancedb
from lance_namespace.errors import NamespaceNotEmptyError, TableNotFoundError
from lancedb.namespace import _MAX_QUERY_K
from lancedb.table import AsyncTable, LanceTable
@@ -65,9 +65,6 @@ def _namespace_lance_table(namespace_client: _NamespaceClient) -> LanceTable:
table._namespace_path = ["geneva"]
table._namespace_client = namespace_client
table._pushdown_operations = {"QueryTable"}
# This test exercises the Python-side pushdown path (non-native client), so
# pushdown is not routed to Rust.
table._route_pushdown_to_rust = False
return table
@@ -808,37 +805,6 @@ class TestPushdownOperations:
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
assert len(db._namespace_client_pushdown_operations) == 0
def test_route_pushdown_to_rust_for_native_rest(self):
"""A natively-built rest connection must defer QueryTable pushdown to
Rust so reads carry the x-lancedb-min-timestamp read-freshness header."""
db = lancedb.connect_namespace(
"rest",
{"uri": "http://localhost:12345"},
namespace_client_pushdown_operations=["QueryTable"],
)
assert db._route_pushdown_to_rust is True
def test_route_pushdown_to_rust_false_for_dir(self):
"""A non-native (dir) connection keeps the Python pushdown path."""
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
assert db._route_pushdown_to_rust is False
def test_async_route_pushdown_to_rust_for_native_rest(self):
"""The async connection must not silently bypass the read-freshness fix:
a natively-built rest connection defers pushdown to Rust (regression test
for the async path omitting the freshness header)."""
db = lancedb.connect_namespace_async(
"rest",
{"uri": "http://localhost:12345"},
namespace_client_pushdown_operations=["QueryTable"],
)
assert db._route_pushdown_to_rust is True
def test_async_route_pushdown_to_rust_false_for_dir(self):
"""The async non-native (dir) connection keeps the Python pushdown path."""
db = lancedb.connect_namespace_async("dir", {"root": self.temp_dir})
assert db._route_pushdown_to_rust is False
def test_lance_table_to_arrow_uses_query_pushdown(self):
namespace_client = _NamespaceClient()
table = _namespace_lance_table(namespace_client)
@@ -850,13 +816,10 @@ class TestPushdownOperations:
["geneva", "hist"],
["geneva", "hist"],
]
# Unlimited reads cap k at i32::MAX (the namespace query_table `k`
# field is i32); sys.maxsize would overflow the Rust binding.
assert [request.k for request in namespace_client.requests] == [
_MAX_QUERY_K,
_MAX_QUERY_K,
sys.maxsize,
sys.maxsize,
]
assert all(r.k <= 2**31 - 1 for r in namespace_client.requests)
@pytest.mark.asyncio
@@ -911,13 +874,10 @@ class TestAsyncPushdownOperations:
["geneva", "hist"],
["geneva", "hist"],
]
# Unlimited reads cap k at i32::MAX (the namespace query_table `k`
# field is i32); sys.maxsize would overflow the Rust binding.
assert [request.k for request in namespace_client.requests] == [
_MAX_QUERY_K,
_MAX_QUERY_K,
sys.maxsize,
sys.maxsize,
]
assert all(r.k <= 2**31 - 1 for r in namespace_client.requests)
def test_local_table_to_arrow_and_to_pandas_are_unchanged(tmp_path):

View File

@@ -188,18 +188,6 @@ def test_nested_struct_list():
assert schema == expect_schema
def test_bare_generic_raises_type_error():
# A bare, unparameterised List/Tuple has no element type to map to Arrow.
# It should raise a clear TypeError, not crash with AttributeError: __args__.
for bare in (List, Tuple):
class TestModel(pydantic.BaseModel):
items: bare
with pytest.raises(TypeError, match="unsupported type"):
pydantic_to_schema(TestModel)
def test_nested_struct_list_optional():
class SplitInfo(pydantic.BaseModel):
start_frame: int

View File

@@ -301,16 +301,6 @@ def test_create_table(mem_db: DBConnection):
assert expected == tbl
def test_create_table_rejects_single_dictionary(mem_db: DBConnection):
data = {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}
with pytest.raises(ValueError) as excep_info:
mem_db.create_table("test", data=data)
assert (
str(excep_info.value) == "Cannot create or add rows from a single dictionary. "
"Use a list of dictionaries instead."
)
def test_empty_table(mem_db: DBConnection):
schema = pa.schema(
[
@@ -340,8 +330,8 @@ def test_add_dictionary(mem_db: DBConnection):
with pytest.raises(ValueError) as excep_info:
tbl.add(data=data)
assert (
str(excep_info.value) == "Cannot create or add rows from a single dictionary. "
"Use a list of dictionaries instead."
str(excep_info.value)
== "Cannot add a single dictionary to a table. Use a list."
)

View File

@@ -149,21 +149,6 @@ def test_value_to_sql_dict():
assert value_to_sql({}) == "named_struct()"
def test_value_to_sql_dict_key_escaping():
# Struct field names that contain a single quote must be escaped (doubled)
# the same way string values are, otherwise value_to_sql emits invalid SQL
# such as named_struct('it's', 1).
assert value_to_sql({"it's": 1}) == "named_struct('it''s', 1)"
assert (
value_to_sql({"o'brien": "d'angelo"}) == "named_struct('o''brien', 'd''angelo')"
)
# Escaping also applies to keys of nested structs.
assert (
value_to_sql({"outer": {"in'r": 1}})
== "named_struct('outer', named_struct('in''r', 1))"
)
def test_value_to_sql_numpy_scalars():
# numpy scalars (e.g. pulled from an ndarray or a pandas column) must
# convert the same way as their native Python counterparts. np.float64

View File

@@ -18,10 +18,7 @@ use lancedb::{
connection::Connection as LanceConnection,
connection::NamespaceClientPushdownOperation,
database::namespace::LanceNamespaceDatabase,
database::{
CreateFunctionRequest, CreateMaterializedViewRequest, CreateTableMode, Database,
ReadConsistency, RefreshMaterializedViewRequest, TableLineageRequest,
},
database::{CreateTableMode, Database, ReadConsistency},
};
use pyo3::{
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -30,92 +27,6 @@ use pyo3::{
types::{PyDict, PyDictMethods},
};
/// A registered function, as returned by `list_functions`.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct FunctionInfo {
pub name: String,
pub language: String,
pub return_type: String,
pub description: String,
}
/// A registered materialized view definition.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct MaterializedViewInfo {
pub name: String,
pub source_table: String,
pub projection: Vec<String>,
pub udf_columns: Vec<String>,
pub filter: Option<String>,
pub auto_refresh: bool,
}
/// One inflight server-side job.
#[pyclass(get_all)]
#[derive(Clone)]
pub struct JobInfo {
pub table: String,
pub job_id: String,
pub job_type: String,
pub state: String,
pub column: Option<String>,
pub age_seconds: Option<i64>,
pub command: Option<String>,
pub units_done: Option<i64>,
pub units_total: Option<i64>,
pub committed: bool,
pub rows_skipped: u64,
pub error: Option<String>,
}
/// One durable, completed/terminal server-side job record (SHOW JOB HISTORY).
#[pyclass(get_all)]
#[derive(Clone)]
pub struct JobHistoryEntry {
pub table: String,
pub job_id: String,
pub job_type: String,
pub state: String,
pub column: Option<String>,
pub created_ms: i64,
pub updated_ms: i64,
pub completed_ms: Option<i64>,
pub rows_processed: Option<i64>,
pub rows_skipped: Option<i64>,
pub error: Option<String>,
pub events: Option<String>,
}
/// One per-row UDF error recorded by `error_policy=skip` (SHOW ERRORS).
#[pyclass(get_all)]
#[derive(Clone)]
pub struct JobErrorEntry {
pub job_id: String,
pub table: String,
pub column: String,
pub error_type: String,
pub error_message: String,
pub fragment_id: Option<i64>,
pub source_row_id: Option<i64>,
pub table_version: Option<i64>,
pub age_seconds: Option<i64>,
}
/// The plan a REFRESH MATERIALIZED VIEW would execute (EXPLAIN REFRESH).
#[pyclass(get_all)]
#[derive(Clone)]
pub struct MvRefreshPlan {
pub table_name: String,
pub has_work: bool,
pub source_version: u64,
pub last_refreshed_version: Option<u64>,
pub full_refresh: bool,
pub rebuild: bool,
pub units_total: u64,
}
#[pyclass]
pub struct Connection {
inner: Option<LanceConnection>,
@@ -399,308 +310,6 @@ impl Connection {
})
}
#[pyo3(signature = (name, language, return_type, body, options=None))]
pub fn create_function(
self_: PyRef<'_, Self>,
name: String,
language: String,
return_type: String,
body: String,
options: Option<HashMap<String, String>>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner
.create_function(CreateFunctionRequest {
name,
language,
return_type,
body,
options: options.unwrap_or_default(),
})
.await
.infer_error()
})
}
pub fn list_functions(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
let functions = inner.list_functions().await.infer_error()?;
Ok(functions
.into_iter()
.map(|f| FunctionInfo {
name: f.name,
language: f.language,
return_type: f.return_type,
description: f.description,
})
.collect::<Vec<_>>())
})
}
pub fn drop_function(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner.drop_function(&name).await.infer_error()
})
}
#[pyo3(signature = (name, query, auto_refresh=false, with_no_data=false, partition_by=None))]
pub fn create_materialized_view(
self_: PyRef<'_, Self>,
name: String,
query: String,
auto_refresh: bool,
with_no_data: bool,
partition_by: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner
.create_materialized_view(CreateMaterializedViewRequest {
name,
query,
auto_refresh,
with_no_data,
partition_by,
})
.await
.infer_error()
})
}
#[pyo3(signature = (name, full=false, src_version=None, num_workers=None, max_workers=None))]
pub fn refresh_materialized_view(
self_: PyRef<'_, Self>,
name: String,
full: bool,
src_version: Option<u64>,
num_workers: Option<u32>,
max_workers: Option<u32>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner
.refresh_materialized_view(RefreshMaterializedViewRequest {
name,
full,
src_version,
num_workers,
max_workers,
})
.await
.infer_error()
})
}
/// Derived-compute lineage of a table/view (or column), returned as the
/// server's lineage JSON string (the Python layer parses it).
pub fn table_lineage(
self_: PyRef<'_, Self>,
name: String,
column: Option<String>,
direction: Option<String>,
depth: Option<u32>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner
.table_lineage(TableLineageRequest {
name,
column,
direction,
depth,
})
.await
.infer_error()
})
}
#[pyo3(signature = (name, full=false, src_version=None))]
pub fn explain_refresh_materialized_view(
self_: PyRef<'_, Self>,
name: String,
full: bool,
src_version: Option<u64>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
let p = inner
.explain_refresh_materialized_view(&name, full, src_version)
.await
.infer_error()?;
Ok(MvRefreshPlan {
table_name: p.table_name,
has_work: p.has_work,
source_version: p.source_version,
last_refreshed_version: p.last_refreshed_version,
full_refresh: p.full_refresh,
rebuild: p.rebuild,
units_total: p.units_total,
})
})
}
pub fn alter_materialized_view(
self_: PyRef<'_, Self>,
name: String,
auto_refresh: bool,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner
.alter_materialized_view(&name, auto_refresh)
.await
.infer_error()
})
}
pub fn drop_materialized_view(
self_: PyRef<'_, Self>,
name: String,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner.drop_materialized_view(&name).await.infer_error()
})
}
pub fn list_materialized_views(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
let views = inner.list_materialized_views().await.infer_error()?;
Ok(views
.into_iter()
.map(|v| MaterializedViewInfo {
name: v.name,
source_table: v.source_table,
projection: v.projection,
udf_columns: v.udf_columns,
filter: v.filter,
auto_refresh: v.auto_refresh,
})
.collect::<Vec<_>>())
})
}
pub fn list_jobs(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
let jobs = inner.list_jobs().await.infer_error()?;
Ok(jobs
.into_iter()
.map(|j| JobInfo {
table: j.table,
job_id: j.job_id,
job_type: j.job_type,
state: j.state,
column: j.column,
age_seconds: j.age_seconds,
command: j.command,
units_done: j.units_done,
units_total: j.units_total,
committed: j.committed,
rows_skipped: j.rows_skipped,
error: j.error,
})
.collect::<Vec<_>>())
})
}
pub fn cancel_job(self_: PyRef<'_, Self>, job_id: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner.cancel_job(&job_id).await.infer_error()
})
}
#[pyo3(signature = (job_id, table=None))]
pub fn get_job(
self_: PyRef<'_, Self>,
job_id: String,
table: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
let job = inner
.get_job(&job_id, table.as_deref())
.await
.infer_error()?;
Ok(job.map(|j| JobInfo {
table: j.table,
job_id: j.job_id,
job_type: j.job_type,
state: j.state,
column: j.column,
age_seconds: j.age_seconds,
command: j.command,
units_done: j.units_done,
units_total: j.units_total,
committed: j.committed,
rows_skipped: j.rows_skipped,
error: j.error,
}))
})
}
#[pyo3(signature = (job_id=None))]
pub fn job_history(
self_: PyRef<'_, Self>,
job_id: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
let rows = inner.job_history(job_id.as_deref()).await.infer_error()?;
Ok(rows
.into_iter()
.map(|r| JobHistoryEntry {
table: r.table,
job_id: r.job_id,
job_type: r.job_type,
state: r.state,
column: r.column,
created_ms: r.created_ms,
updated_ms: r.updated_ms,
completed_ms: r.completed_ms,
rows_processed: r.rows_processed,
rows_skipped: r.rows_skipped,
error: r.error,
events: r.events,
})
.collect::<Vec<_>>())
})
}
#[pyo3(signature = (job_id=None, table=None))]
pub fn errors(
self_: PyRef<'_, Self>,
job_id: Option<String>,
table: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
let rows = inner
.errors(job_id.as_deref(), table.as_deref())
.await
.infer_error()?;
Ok(rows
.into_iter()
.map(|e| JobErrorEntry {
job_id: e.job_id,
table: e.table,
column: e.column,
error_type: e.error_type,
error_message: e.error_message,
fragment_id: e.fragment_id,
source_row_id: e.source_row_id,
table_version: e.table_version,
age_seconds: e.age_seconds,
})
.collect::<Vec<_>>())
})
}
#[pyo3(signature = (cur_name, new_name, cur_namespace_path=None, new_namespace_path=None))]
pub fn rename_table(
self_: PyRef<'_, Self>,
@@ -930,7 +539,7 @@ impl Connection {
}
#[pyfunction]
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
#[allow(clippy::too_many_arguments)]
pub fn connect(
py: Python<'_>,
@@ -944,7 +553,6 @@ pub fn connect(
session: Option<crate::session::Session>,
manifest_enabled: bool,
namespace_client_properties: Option<HashMap<String, String>>,
oauth_config: Option<crate::oauth::PyOAuthConfig>,
) -> PyResult<Bound<'_, PyAny>> {
future_into_py(py, async move {
let mut builder = lancedb::connect(&uri);
@@ -974,11 +582,6 @@ pub fn connect(
if let Some(client_config) = client_config {
builder = builder.client_config(client_config.into());
}
if let Some(oauth_config) = oauth_config {
let config: lancedb::remote::oauth::OAuthConfig =
oauth_config.try_into().infer_error()?;
builder = builder.oauth_config(config);
}
if let Some(session) = session {
builder = builder.session(session.inner.clone());
}
@@ -1007,38 +610,24 @@ pub fn connect_namespace_client(
namespace_client_impl: Option<String>,
namespace_client_properties: Option<HashMap<String, String>>,
) -> PyResult<Connection> {
let namespace_client = extract_namespace_arc(py, namespace_client)?;
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
let namespace_client_pushdown_operations =
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string());
let ns_properties = namespace_client_properties.unwrap_or_default();
let storage_options = storage_options.unwrap_or_default();
let session = session.map(|s| s.inner.clone());
// Prefer building the namespace natively from (impl, properties) so the
// read-freshness provider installed
let database = if build_namespace_natively(namespace_client_impl.as_deref(), &ns_properties) {
let ns_impl = namespace_client_impl.expect("impl present per build_namespace_natively");
crate::runtime::block_on(LanceNamespaceDatabase::connect(
&ns_impl,
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
))
.infer_error()?
} else {
let namespace_client = extract_namespace_arc(py, namespace_client)?;
LanceNamespaceDatabase::from_namespace_client(
namespace_client,
namespace_client_impl.unwrap_or_else(|| "python".to_string()),
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
)
};
let database = LanceNamespaceDatabase::from_namespace_client(
namespace_client,
ns_impl,
ns_properties,
storage_options,
read_consistency_interval,
session,
namespace_client_pushdown_operations,
);
Ok(Connection::new(LanceConnection::new(
Arc::new(database),
@@ -1046,16 +635,6 @@ pub fn connect_namespace_client(
)))
}
/// Whether to build the namespace natively (from impl + properties) instead of
/// wrapping a pre-built client. Native construction is required for the
/// read-freshness provider to be installed
fn build_namespace_natively(
namespace_client_impl: Option<&str>,
namespace_client_properties: &HashMap<String, String>,
) -> bool {
matches!(namespace_client_impl, Some("rest")) && !namespace_client_properties.is_empty()
}
#[derive(FromPyObject)]
pub struct PyClientConfig {
user_agent: String,
@@ -1154,36 +733,3 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn props(pairs: &[(&str, &str)]) -> HashMap<String, String> {
pairs
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect()
}
#[test]
fn native_build_only_for_rest_with_properties() {
let rest = props(&[("uri", "http://localhost:10024")]);
// rest + non-empty properties -> build natively (installs the
// read-freshness provider so checkout_latest() busts the server cache).
assert!(build_namespace_natively(Some("rest"), &rest));
// dir is local (no server cache) -> wrap the pre-built client unchanged.
assert!(!build_namespace_natively(
Some("dir"),
&props(&[("root", "/tmp")])
));
// No impl: only a pre-built client was handed in -> wrap it as-is.
assert!(!build_namespace_natively(None, &rest));
// rest but no properties: nothing to build a connection from -> wrap.
assert!(!build_namespace_natively(Some("rest"), &HashMap::new()));
}
}

View File

@@ -319,53 +319,11 @@ pub struct IndexConfig {
#[pymethods]
impl IndexConfig {
pub fn __repr__(&self, py: Python<'_>) -> String {
let mut fields = vec![
format!("name={:?}", self.name),
format!("index_type={:?}", self.index_type),
format!("columns={:?}", self.columns),
];
if let Some(v) = &self.index_uuid {
fields.push(format!("index_uuid={:?}", v));
}
if let Some(v) = &self.type_url {
fields.push(format!("type_url={:?}", v));
}
if let Some(v) = self.created_at {
// Render the datetime's own Python repr so the value round-trips,
// falling back to RFC 3339 if the conversion ever fails.
let rendered = v
.into_pyobject(py)
.ok()
.and_then(|obj| obj.into_any().repr().ok())
.map(|r| r.to_string())
.unwrap_or_else(|| v.to_rfc3339());
fields.push(format!("created_at={}", rendered));
}
if let Some(v) = self.num_indexed_rows {
fields.push(format!("num_indexed_rows={}", fmt_thousands(v)));
}
if let Some(v) = self.num_unindexed_rows {
fields.push(format!("num_unindexed_rows={}", fmt_thousands(v)));
}
if let Some(v) = self.size_bytes {
fields.push(format!("size_bytes={}", fmt_thousands(v)));
}
if let Some(v) = self.num_segments {
fields.push(format!("num_segments={}", v));
}
if let Some(v) = self.index_version {
fields.push(format!("index_version={}", v));
}
if let Some(v) = &self.index_details {
let details = v
.bind(py)
.repr()
.map(|r| r.to_string())
.unwrap_or_else(|_| "<unavailable>".to_string());
fields.push(format!("index_details={}", details));
}
format!("IndexConfig({})", fields.join(", "))
pub fn __repr__(&self) -> String {
format!(
"Index({}, columns={:?}, name=\"{}\")",
self.index_type, self.columns, self.name
)
}
// For backwards-compatibility with the old sync SDK, we also support getting
@@ -394,23 +352,6 @@ impl IndexConfig {
}
}
/// Format an integer with `_` thousands separators, e.g. `24_500_213`.
///
/// Underscores are valid Python int-literal syntax, so the repr stays
/// copy-pasteable and machine-parseable while remaining readable.
fn fmt_thousands(n: u64) -> String {
let digits = n.to_string();
let bytes = digits.as_bytes();
let mut out = String::with_capacity(digits.len() + digits.len() / 3);
for (i, b) in bytes.iter().enumerate() {
if i > 0 && (bytes.len() - i).is_multiple_of(3) {
out.push('_');
}
out.push(*b as char);
}
out
}
fn parse_index_details(py: Python<'_>, s: String) -> Py<PyAny> {
let json = py.import("json").expect("json module is always available");
match json.call_method1("loads", (s.as_str(),)) {

View File

@@ -26,7 +26,6 @@ pub mod expr;
pub mod header;
pub mod index;
pub mod namespace;
pub mod oauth;
pub mod permutation;
pub mod query;
pub mod runtime;
@@ -41,11 +40,6 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
.write_style("LANCEDB_LOG_STYLE");
env_logger::init_from_env(env);
m.add_class::<Connection>()?;
m.add_class::<connection::FunctionInfo>()?;
m.add_class::<connection::MaterializedViewInfo>()?;
m.add_class::<connection::JobInfo>()?;
m.add_class::<connection::JobHistoryEntry>()?;
m.add_class::<connection::JobErrorEntry>()?;
m.add_class::<Session>()?;
m.add_class::<Table>()?;
m.add_class::<IndexConfig>()?;

View File

@@ -1,72 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use pyo3::FromPyObject;
use lancedb::error::Error;
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
/// Python-side OAuth configuration, extracted via FromPyObject.
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
#[derive(FromPyObject)]
pub struct PyOAuthConfig {
pub issuer_url: String,
pub client_id: String,
pub scopes: Vec<String>,
pub flow: String,
pub client_secret: Option<String>,
pub managed_identity_client_id: Option<String>,
pub refresh_buffer_secs: Option<u64>,
}
impl TryFrom<PyOAuthConfig> for OAuthConfig {
type Error = Error;
fn try_from(py: PyOAuthConfig) -> Result<Self, Self::Error> {
let flow = match py.flow.as_str() {
"client_credentials" => OAuthFlow::ClientCredentials,
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
client_id: py.managed_identity_client_id,
},
other => {
return Err(Error::InvalidInput {
message: format!("Unknown OAuth flow type: {other}"),
});
}
};
Ok(Self {
issuer_url: py.issuer_url,
client_id: py.client_id,
client_secret: py.client_secret,
scopes: py.scopes,
flow,
refresh_buffer_secs: py.refresh_buffer_secs,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_unknown_oauth_flow_returns_invalid_input() {
let config = PyOAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
scopes: vec!["scope".to_string()],
flow: "typo".to_string(),
client_secret: None,
managed_identity_client_id: None,
refresh_buffer_secs: None,
};
let err = OAuthConfig::try_from(config).unwrap_err();
assert!(matches!(
err,
Error::InvalidInput { message }
if message == "Unknown OAuth flow type: typo"
));
}
}

View File

@@ -56,15 +56,6 @@ fn get_runtime() -> &'static runtime::Runtime {
unsafe { &*new_ptr }
}
/// Block the current thread on a future using the shared runtime.
///
/// For sync `#[pyfunction]`s that need to drive an async operation (e.g.
/// building a namespace client). Must not be called from within the runtime's
/// own worker threads.
pub fn block_on<F: std::future::Future>(fut: F) -> F::Output {
get_runtime().block_on(fut)
}
/// Runs in async-signal context after `fork()` in the child. We can only
/// touch atomics here; we deliberately leak the previous runtime because
/// dropping a tokio `Runtime` would try to join its (now-dead) worker

View File

@@ -17,8 +17,8 @@ use arrow::{
pyarrow::{FromPyArrow, PyArrowType, ToPyArrow},
};
use lancedb::table::{
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, LoadColumnsRequest,
NewColumnTransform, OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform,
OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
};
use pyo3::{
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -1060,83 +1060,6 @@ impl Table {
})
}
pub fn add_computed_columns(
self_: PyRef<'_, Self>,
columns: Vec<(String, String)>,
expression: String,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner
.add_computed_columns(&columns, &expression)
.await
.infer_error()
})
}
#[pyo3(signature = (columns, where_clause=None, num_workers=None, max_workers=None, batch_size=None, priority=None))]
pub fn refresh_column(
self_: PyRef<'_, Self>,
columns: Vec<String>,
where_clause: Option<String>,
num_workers: Option<u32>,
max_workers: Option<u32>,
batch_size: Option<u32>,
priority: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner
.refresh_column(
&columns,
where_clause,
num_workers,
max_workers,
batch_size,
priority,
)
.await
.infer_error()
})
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (source_uris, source_format, target_key, columns, source_key=None, source_storage_options=None, on_missing=None, num_workers=None, max_workers=None, batch_size=None, commit_granularity=None, priority=None))]
pub fn load_columns(
self_: PyRef<'_, Self>,
source_uris: Vec<String>,
source_format: String,
target_key: String,
columns: Vec<(String, Option<String>)>,
source_key: Option<String>,
source_storage_options: Option<std::collections::HashMap<String, String>>,
on_missing: Option<String>,
num_workers: Option<u32>,
max_workers: Option<u32>,
batch_size: Option<u32>,
commit_granularity: Option<u32>,
priority: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
let request = LoadColumnsRequest {
source_uris,
source_format,
source_storage_options,
target_key,
source_key,
columns,
on_missing,
num_workers,
max_workers,
batch_size,
commit_granularity,
priority,
};
future_into_py(self_.py(), async move {
inner.load_columns(request).await.infer_error()
})
}
pub fn add_columns(
self_: PyRef<'_, Self>,
definitions: Vec<(String, String)>,

View File

@@ -1,33 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import importlib.util
import sys
from pathlib import Path
def _load_oauth_module():
oauth_path = (
Path(__file__).parents[1] / "python" / "lancedb" / "remote" / "oauth.py"
)
spec = importlib.util.spec_from_file_location("lancedb_remote_oauth", oauth_path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module
def test_oauth_config_repr_redacts_client_secret():
oauth = _load_oauth_module()
config = oauth.OAuthConfig(
issuer_url="https://issuer.example.com",
client_id="client-id",
scopes=["scope"],
client_secret="super-secret",
)
rendered = repr(config)
assert "super-secret" not in rendered
assert "client_secret" not in rendered

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.31.0-beta.4"
version = "0.30.1-beta.2"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -50,7 +50,7 @@ lance-namespace = { workspace = true }
lance-namespace-impls = { workspace = true }
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread", "sync"] }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log.workspace = true
async-trait = "0"
bytes = "1"
@@ -75,7 +75,6 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"stream",
], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
urlencoding = { version = "2", optional = true }
uuid = { version = "1.7.0", features = ["v4", "v5"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
@@ -94,7 +93,6 @@ semver = { workspace = true }
anyhow = "1"
tempfile = "3.5.0"
random_word = { version = "0.4.3", features = ["en"] }
tokio = { version = "1.23", features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
aws-sdk-dynamodb = { version = "1.55.0" }
@@ -131,13 +129,7 @@ huggingface = [
"lance-namespace-impls/dir-huggingface",
]
dynamodb = ["lance/dynamodb", "aws"]
remote = [
"dep:reqwest",
"dep:http",
"dep:urlencoding",
"lance-namespace-impls/rest",
"lance-namespace-impls/rest-adapter",
]
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -1,435 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Lance blob v2 columns store large binary payloads out of line.
//!
//! Declare a column with [`blob`]. On write, [`crate::table::Table::add`] coerces
//! raw `Binary` / `LargeBinary` into the blob struct layout. Queries return
//! small descriptors, not bytes.
//!
//! Blob tables require Lance file format >= 2.2 and stable row ids at create.
use std::sync::Arc;
use arrow_array::builder::LargeBinaryBuilder;
use arrow_array::{Array, LargeBinaryArray, RecordBatch, StructArray, UInt8Array, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use lance::dataset::{Dataset, WriteParams};
use lance_arrow::FieldExt;
use lance_core::datatypes::parse_field_path;
use lance_encoding::version::LanceFileVersion;
use crate::error::{Error, Result};
pub use lance::dataset::BlobFile;
/// Creates an Arrow field for a Lance blob v2 column.
///
/// `Struct<data, uri>` with the `lance.blob.v2` marker. Same layout Lance
/// expects on write.
///
/// A blob column may be top-level or nested inside a struct or list. Nested
/// blobs are addressed by a dotted path (e.g. `info.blob`) in the read APIs.
///
/// ```
/// use arrow_schema::{DataType, Field, Schema};
///
/// let schema = Schema::new(vec![
/// Field::new("id", DataType::Int64, false),
/// lancedb::blob("image", true),
/// ]);
/// ```
pub fn blob(name: impl AsRef<str>, nullable: bool) -> Field {
lance::blob::blob_field(name.as_ref(), nullable)
}
/// Returns true if `field` is a blob v2 column.
///
/// ```
/// let field = lancedb::blob("image", true);
/// assert!(lancedb::blob::is_blob(&field));
/// ```
pub fn is_blob(field: &Field) -> bool {
field.is_blob_v2()
}
/// Returns true if `field`, or any field nested under it, is a blob v2 column.
fn field_tree_has_blob_v2(field: &Field) -> bool {
if field.is_blob_v2() {
return true;
}
match field.data_type() {
DataType::Struct(children) => children.iter().any(|c| field_tree_has_blob_v2(c)),
DataType::List(child) | DataType::LargeList(child) | DataType::FixedSizeList(child, _) => {
field_tree_has_blob_v2(child)
}
_ => false,
}
}
/// Collects the dotted paths of blob v2 columns under `field`, into `paths`.
fn collect_blob_paths(field: &Field, prefix: &str, paths: &mut Vec<String>) {
let path = if prefix.is_empty() {
field.name().clone()
} else {
format!("{prefix}.{}", field.name())
};
if field.is_blob_v2() {
paths.push(path);
return;
}
match field.data_type() {
DataType::Struct(children) => {
for child in children {
collect_blob_paths(child, &path, paths);
}
}
DataType::List(child) | DataType::LargeList(child) | DataType::FixedSizeList(child, _) => {
collect_blob_paths(child, &path, paths)
}
_ => {}
}
}
/// Returns true if `schema` declares any blob v2 column, including nested ones.
pub(crate) fn has_blob_columns(schema: &Schema) -> bool {
schema.fields().iter().any(|f| field_tree_has_blob_v2(f))
}
/// Blob v2 column paths in `schema`, declaration order preserved. Nested blobs
/// are dotted paths (e.g. `info.blob`).
pub(crate) fn blob_column_names(schema: &Schema) -> Vec<String> {
let mut paths = Vec::new();
for field in schema.fields() {
collect_blob_paths(field, "", &mut paths);
}
paths
}
/// Bumps storage format to at least [`LanceFileVersion::V2_2`] for blob schemas.
pub(crate) fn ensure_blob_storage_version(schema: &Schema, params: &mut WriteParams) {
if !has_blob_columns(schema) {
return;
}
let resolved = params
.data_storage_version
.unwrap_or(LanceFileVersion::Stable)
.resolve();
if resolved < LanceFileVersion::V2_2 {
params.data_storage_version = Some(LanceFileVersion::V2_2);
}
}
/// Validate that `column` exists and is a blob v2 column.
///
/// Legacy v1 columns (`lance-encoding:blob`) error with a migration hint.
pub(crate) fn ensure_blob_v2_column(
schema: &lance_core::datatypes::Schema,
column: &str,
) -> Result<()> {
match schema.field(column) {
Some(field) if field.is_blob_v2() => Ok(()),
Some(field) if field.is_blob() => Err(Error::InvalidInput {
message: format!(
"column '{column}' is a legacy blob column; blob APIs require blob v2 columns \
(ARROW:extension:name = \"lance.blob.v2\")"
),
}),
Some(_) => Err(Error::InvalidInput {
message: format!("column '{column}' is not a blob column"),
}),
None => Err(Error::InvalidInput {
message: format!("no column named '{column}' in this table"),
}),
}
}
/// Returns the leaf descriptor `StructArray` for `column` in a descriptor batch.
fn leaf_descriptor_struct<'a>(batch: &'a RecordBatch, column: &str) -> Result<&'a StructArray> {
let path = parse_field_path(column).map_err(|e| Error::InvalidInput {
message: format!("invalid blob column path '{column}': {e}"),
})?;
let not_struct = || Error::Runtime {
message: format!("blob column '{column}' did not read back as a descriptor struct"),
};
let mut current = batch
.column_by_name(&path[0])
.and_then(|c| c.as_any().downcast_ref::<StructArray>())
.ok_or_else(not_struct)?;
for segment in &path[1..] {
current = current
.column_by_name(segment)
.and_then(|c| c.as_any().downcast_ref::<StructArray>())
.ok_or_else(not_struct)?;
}
Ok(current)
}
/// Null rows in `row_ids`, from a descriptor take.
///
/// Lance `read_blobs` / `take_blobs` skip null rows (`kind == 0 && position == 0 && size == 0`).
/// TODO(lance): aligned read API would drop this pass.
async fn blob_null_mask(
dataset: &Arc<Dataset>,
column: &str,
row_ids: &[u64],
) -> Result<Vec<bool>> {
let projection = dataset.schema().project(&[column])?;
let descriptors = dataset.take_builder(row_ids, projection)?.execute().await?;
if descriptors.num_rows() != row_ids.len() {
return Err(Error::InvalidInput {
message: format!(
"blob take for column '{column}' requested {} row ids but only {} exist in the \
table; pass row ids collected from this table",
row_ids.len(),
descriptors.num_rows()
),
});
}
let descriptor_struct = leaf_descriptor_struct(&descriptors, column)?;
let child = |name: &str| {
descriptor_struct
.column_by_name(name)
.ok_or_else(|| Error::Runtime {
message: format!("blob descriptor for '{column}' is missing the '{name}' field"),
})
};
let kinds = child("kind")?
.as_any()
.downcast_ref::<UInt8Array>()
.ok_or_else(|| Error::Runtime {
message: format!("blob descriptor 'kind' for '{column}' is not a UInt8 array"),
})?;
let positions = child("position")?
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| Error::Runtime {
message: format!("blob descriptor 'position' for '{column}' is not a UInt64 array"),
})?;
let sizes = child("size")?
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or_else(|| Error::Runtime {
message: format!("blob descriptor 'size' for '{column}' is not a UInt64 array"),
})?;
// Match Lance `collect_blob_entries_v2` skip condition (`BlobKind::Inline` == 0).
Ok((0..descriptor_struct.len())
.map(|i| {
descriptor_struct.is_null(i)
|| kinds.is_null(i)
|| (kinds.value(i) == 0 && positions.value(i) == 0 && sizes.value(i) == 0)
})
.collect())
}
fn non_null_row_ids(row_ids: &[u64], null_mask: &[bool]) -> Vec<u64> {
row_ids
.iter()
.zip(null_mask)
.filter_map(|(row_id, is_null)| (!is_null).then_some(*row_id))
.collect()
}
/// Materialize blob bytes for `row_ids` (same length and order, nulls preserved).
pub(crate) async fn take_blobs_aligned(
dataset: &Arc<Dataset>,
column: &str,
row_ids: &[u64],
) -> Result<LargeBinaryArray> {
ensure_blob_v2_column(dataset.schema(), column)?;
if row_ids.is_empty() {
return Ok(LargeBinaryBuilder::new().finish());
}
let null_mask = blob_null_mask(dataset, column, row_ids).await?;
let non_null_row_ids = non_null_row_ids(row_ids, &null_mask);
let non_null_count = non_null_row_ids.len();
let payloads = if non_null_count == 0 {
Vec::new()
} else {
dataset
.read_blobs(column)?
.with_row_ids(non_null_row_ids)
.preserve_order(true)
.execute()
.await?
};
if payloads.len() != non_null_count {
return Err(Error::Runtime {
message: format!(
"blob read for column '{column}' returned {} payloads for {} non-null rows",
payloads.len(),
non_null_count
),
});
}
let mut builder = LargeBinaryBuilder::new();
let mut payload_idx = 0;
for is_null in &null_mask {
if *is_null {
builder.append_null();
} else {
builder.append_value(payloads[payload_idx].data.as_ref());
payload_idx += 1;
}
}
Ok(builder.finish())
}
/// Open lazy [`BlobFile`] handles for `row_ids` (same length and order, nulls as `None`).
pub(crate) async fn take_blob_files_aligned(
dataset: &Arc<Dataset>,
column: &str,
row_ids: &[u64],
) -> Result<Vec<Option<BlobFile>>> {
ensure_blob_v2_column(dataset.schema(), column)?;
if row_ids.is_empty() {
return Ok(Vec::new());
}
let null_mask = blob_null_mask(dataset, column, row_ids).await?;
let non_null_row_ids = non_null_row_ids(row_ids, &null_mask);
let handles = if non_null_row_ids.is_empty() {
Vec::new()
} else {
dataset.take_blobs(&non_null_row_ids, column).await?
};
if handles.len() != non_null_row_ids.len() {
return Err(Error::Runtime {
message: format!(
"blob take for column '{column}' returned {} handles for {} non-null rows",
handles.len(),
non_null_row_ids.len()
),
});
}
let mut handles = handles.into_iter();
Ok(null_mask
.iter()
.map(|is_null| {
if *is_null {
None
} else {
Some(handles.next().unwrap())
}
})
.collect())
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_schema::DataType;
use lance_arrow::ARROW_EXT_NAME_KEY;
fn blob_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int64, false),
blob("image", true),
])
}
#[test]
fn blob_field_carries_v2_extension_marker() {
let field = blob("image", true);
assert_eq!(
field.metadata().get(ARROW_EXT_NAME_KEY).map(String::as_str),
Some("lance.blob.v2")
);
assert!(matches!(field.data_type(), DataType::Struct(_)));
}
#[test]
fn has_blob_columns_detects_blob_fields() {
assert!(has_blob_columns(&blob_schema()));
let plain = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
assert!(!has_blob_columns(&plain));
}
#[test]
fn storage_version_bumps_to_v2_2() {
let mut params = WriteParams::default();
ensure_blob_storage_version(&blob_schema(), &mut params);
assert_eq!(
params.data_storage_version.unwrap().resolve(),
LanceFileVersion::V2_2
);
}
#[test]
fn storage_version_overrides_lower_explicit_version() {
let mut params = WriteParams {
data_storage_version: Some(LanceFileVersion::V2_0),
..Default::default()
};
ensure_blob_storage_version(&blob_schema(), &mut params);
assert_eq!(
params.data_storage_version.unwrap().resolve(),
LanceFileVersion::V2_2
);
}
#[test]
fn storage_version_keeps_higher_explicit_version() {
let mut params = WriteParams {
data_storage_version: Some(LanceFileVersion::V2_3),
..Default::default()
};
ensure_blob_storage_version(&blob_schema(), &mut params);
assert_eq!(params.data_storage_version.unwrap(), LanceFileVersion::V2_3);
}
#[test]
fn legacy_v1_blob_column_is_rejected_with_migration_hint() {
let legacy = Field::new("image", DataType::LargeBinary, true).with_metadata(
std::collections::HashMap::from([(
"lance-encoding:blob".to_string(),
"true".to_string(),
)]),
);
let arrow_schema = Schema::new(vec![legacy]);
let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
let err = ensure_blob_v2_column(&lance_schema, "image").unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));
assert!(err.to_string().contains("legacy blob column"));
assert!(err.to_string().contains("lance.blob.v2"));
}
#[test]
fn non_blob_and_unknown_columns_are_rejected_by_name() {
let arrow_schema = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
let err = ensure_blob_v2_column(&lance_schema, "id").unwrap_err();
assert!(err.to_string().contains("'id' is not a blob column"));
let err = ensure_blob_v2_column(&lance_schema, "missing").unwrap_err();
assert!(err.to_string().contains("no column named 'missing'"));
}
#[test]
fn blob_column_names_includes_nested_path() {
let blob_field = blob("blob", true);
let info = Field::new(
"info",
DataType::Struct(vec![Field::new("name", DataType::Utf8, false), blob_field].into()),
true,
);
let schema = Schema::new(vec![Field::new("id", DataType::Int64, false), info]);
assert_eq!(blob_column_names(&schema), vec!["info.blob"]);
}
#[test]
fn storage_version_noop_without_blob_columns() {
let schema = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
let mut params = WriteParams::default();
ensure_blob_storage_version(&schema, &mut params);
assert!(params.data_storage_version.is_none());
}
}

View File

@@ -23,10 +23,8 @@ use crate::connection::create_table::CreateTableBuilder;
use crate::data::scannable::Scannable;
use crate::database::listing::ListingDatabase;
use crate::database::{
CloneTableRequest, CreateFunctionRequest, CreateMaterializedViewRequest, Database,
DatabaseOptions, FunctionInfo, JobErrorInfo, JobHistoryInfo, JobInfo, MaterializedViewInfo,
MvRefreshPlan, OpenTableRequest, ReadConsistency, RefreshMaterializedViewRequest,
TableLineageRequest, TableNamesRequest,
CloneTableRequest, Database, DatabaseOptions, OpenTableRequest, ReadConsistency,
TableNamesRequest,
};
use crate::embeddings::{EmbeddingRegistry, MemoryRegistry};
use crate::error::{Error, Result};
@@ -490,113 +488,6 @@ impl Connection {
)
}
// -- Derived compute: functions, materialized views, jobs -------------
// Server-backed features (LanceDB Enterprise / Cloud); local
// databases return NotSupported for now.
/// Register a UDF (CREATE FUNCTION).
pub async fn create_function(&self, request: CreateFunctionRequest) -> Result<()> {
self.internal.create_function(request).await
}
/// List registered functions (SHOW FUNCTIONS).
pub async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
self.internal.list_functions().await
}
/// Drop a registered function (DROP FUNCTION).
pub async fn drop_function(&self, name: &str) -> Result<()> {
self.internal.drop_function(name).await
}
/// Create a materialized view (CREATE MATERIALIZED VIEW). Returns
/// the initial-population job id, absent when `with_no_data`.
pub async fn create_materialized_view(
&self,
request: CreateMaterializedViewRequest,
) -> Result<Option<String>> {
self.internal.create_materialized_view(request).await
}
/// Refresh a materialized view; returns the refresh job id.
pub async fn refresh_materialized_view(
&self,
request: RefreshMaterializedViewRequest,
) -> Result<String> {
self.internal.refresh_materialized_view(request).await
}
/// Derived-compute lineage of a table/view (or column), as server-defined
/// JSON. Read-only.
pub async fn table_lineage(&self, request: TableLineageRequest) -> Result<String> {
self.internal.table_lineage(request).await
}
/// Plan a materialized-view refresh without submitting work
/// (EXPLAIN REFRESH).
pub async fn explain_refresh_materialized_view(
&self,
name: &str,
full: bool,
src_version: Option<u64>,
) -> Result<MvRefreshPlan> {
self.internal
.explain_refresh_materialized_view(name, full, src_version)
.await
}
/// Update a materialized view's options (ALTER MATERIALIZED VIEW).
pub async fn alter_materialized_view(&self, name: &str, auto_refresh: bool) -> Result<()> {
self.internal
.alter_materialized_view(name, auto_refresh)
.await
}
/// Drop a materialized view definition (DROP MATERIALIZED VIEW).
pub async fn drop_materialized_view(&self, name: &str) -> Result<()> {
self.internal.drop_materialized_view(name).await
}
/// List registered materialized view definitions.
pub async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
self.internal.list_materialized_views().await
}
/// List inflight server-side jobs across the database's tables.
pub async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
self.internal.list_jobs().await
}
/// Cancel an inflight server-side job by id. Returns true if a
/// matching inflight job was flagged for cancellation.
pub async fn cancel_job(&self, job_id: &str) -> Result<bool> {
self.internal.cancel_job(job_id).await
}
/// Look up a single server-side job by id -- the `wait()`/status poll path.
/// `table_hint` (the job's table) enables an O(1) server-side lookup; `None`
/// scans the database's active jobs. A `None` result means unknown / not
/// active.
pub async fn get_job(&self, job_id: &str, table_hint: Option<&str>) -> Result<Option<JobInfo>> {
self.internal.get_job(job_id, table_hint).await
}
/// Durable job history (SHOW JOB HISTORY) across the database's tables.
/// Pass `job_id` to narrow to a single job.
pub async fn job_history(&self, job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
self.internal.job_history(job_id).await
}
/// Per-row UDF errors (SHOW ERRORS) across the database's tables, optionally
/// filtered by `job_id` and/or `table`.
pub async fn errors(
&self,
job_id: Option<&str>,
table: Option<&str>,
) -> Result<Vec<JobErrorInfo>> {
self.internal.errors(job_id, table).await
}
/// Rename a table in the database.
///
/// This is only supported in LanceDB Cloud.
@@ -685,9 +576,6 @@ impl Connection {
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
/// For RemoteDatabase, it is the equivalent RestNamespace.
///
/// Remote connections using dynamic headers forward them through the
/// namespace client's per-request context provider.
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
self.internal.namespace_client().await
}
@@ -696,9 +584,6 @@ impl Connection {
/// Returns (impl_type, properties) where:
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
/// - properties: configuration properties for the namespace
///
/// Remote connections using dynamic headers cannot be exported because the
/// namespace client config only carries static headers.
pub async fn namespace_client_config(
&self,
) -> Result<(String, std::collections::HashMap<String, String>)> {
@@ -776,8 +661,6 @@ pub struct ConnectRequest {
pub struct ConnectBuilder {
request: ConnectRequest,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
#[cfg(feature = "remote")]
oauth_config: Option<crate::remote::OAuthConfig>,
}
#[cfg(feature = "remote")]
@@ -799,8 +682,6 @@ impl ConnectBuilder {
session: None,
},
embedding_registry: None,
#[cfg(feature = "remote")]
oauth_config: None,
}
}
@@ -889,19 +770,6 @@ impl ConnectBuilder {
self
}
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
///
/// This creates an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
/// from the given config and sets it as the header provider. OAuth cannot
/// be combined with an API key or another header provider.
///
/// Token acquisition and refresh are handled in Rust.
#[cfg(feature = "remote")]
pub fn oauth_config(mut self, config: crate::remote::OAuthConfig) -> Self {
self.oauth_config = Some(config);
self
}
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
self.embedding_registry = Some(registry);
@@ -1047,40 +915,9 @@ impl ConnectBuilder {
let region = options.region.ok_or_else(|| Error::InvalidInput {
message: "A region is required when connecting to LanceDb Cloud".to_string(),
})?;
let api_key = match (&self.oauth_config, &options.api_key) {
(Some(_), None) => String::new(),
(Some(_), Some(_)) => {
return Err(Error::InvalidInput {
message:
"api_key and oauth_config cannot both be set when connecting to LanceDb Cloud"
.to_string(),
});
}
(None, Some(key)) => key.clone(),
(None, None) => {
return Err(Error::InvalidInput {
message:
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
.to_string(),
});
}
};
if self.oauth_config.is_some() && self.request.client_config.header_provider.is_some() {
return Err(Error::InvalidInput {
message:
"oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud"
.to_string(),
});
}
let mut client_config = self.request.client_config;
if let Some(oauth_config) = self.oauth_config {
let provider = crate::remote::OAuthHeaderProvider::new(oauth_config)?;
client_config.header_provider =
Some(Arc::new(provider) as Arc<dyn crate::remote::HeaderProvider>);
}
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
let storage_options = StorageOptions(options.storage_options.clone());
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
@@ -1088,7 +925,7 @@ impl ConnectBuilder {
&api_key,
&region,
options.host_override,
client_config,
self.request.client_config,
storage_options.into(),
self.request.read_consistency_interval,
)?);
@@ -1397,83 +1234,6 @@ mod tests {
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
}
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_api_key_with_oauth_config() {
let oauth_config = crate::remote::OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: crate::remote::OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let result = ConnectBuilder::new("db://my-container/my-prefix")
.region("us-east-1")
.api_key("my-api-key")
.oauth_config(oauth_config)
.execute()
.await;
match result {
Err(Error::InvalidInput { message })
if message
== "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud" =>
{}
Err(err) => panic!("expected InvalidInput, got {err:?}"),
Ok(_) => panic!("expected api_key and oauth_config to be rejected"),
}
}
#[cfg(feature = "remote")]
#[tokio::test]
async fn test_connect_rejects_header_provider_with_oauth_config() {
#[derive(Debug)]
struct TestHeaderProvider;
#[async_trait::async_trait]
impl crate::remote::HeaderProvider for TestHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
Ok(HashMap::from([(
"authorization".to_string(),
"Bearer token".to_string(),
)]))
}
}
let oauth_config = crate::remote::OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: crate::remote::OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let client_config = crate::remote::ClientConfig {
header_provider: Some(
Arc::new(TestHeaderProvider) as Arc<dyn crate::remote::HeaderProvider>
),
..Default::default()
};
let result = ConnectBuilder::new("db://my-container/my-prefix")
.region("us-east-1")
.client_config(client_config)
.oauth_config(oauth_config)
.execute()
.await;
match result {
Err(Error::InvalidInput { message })
if message
== "oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud" =>
{}
Err(err) => panic!("expected InvalidInput, got {err:?}"),
Ok(_) => panic!("expected header_provider and oauth_config to be rejected"),
}
}
#[cfg(not(windows))]
#[tokio::test]
async fn test_connect_relative() {

View File

@@ -27,12 +27,11 @@ use lance_namespace::models::{
};
use crate::data::scannable::Scannable;
use crate::error::{Error, Result};
use crate::error::Result;
use crate::table::{BaseTable, WriteOptions};
pub mod listing;
pub mod namespace;
pub(crate) mod read_freshness;
pub trait DatabaseOptions {
fn serialize_into_map(&self, map: &mut HashMap<String, String>);
@@ -200,205 +199,6 @@ pub enum ReadConsistency {
Strong,
}
/// A request to register a UDF (CREATE FUNCTION).
///
/// Functions are first-class database objects, decoupled from any
/// column; computed columns and materialized views reference them by
/// name. Server-backed feature (LanceDB Enterprise / Cloud).
#[derive(Debug, Clone)]
pub struct CreateFunctionRequest {
/// Function name.
pub name: String,
/// Implementation language (currently "python").
pub language: String,
/// SQL return type, e.g. `FLOAT`, `FLOAT[1536]`,
/// `STRUCT(a FLOAT, b VARCHAR)`, `TABLE(chunk VARCHAR, idx INT)`.
pub return_type: String,
/// Function body: source text, or base64 cloudpickle bytes when
/// `options["body_format"] = "cloudpickle"`.
pub body: String,
/// Options: input_columns, pip, num_gpus, batch_size, timeout,
/// error_policy, docker_image, body_format, ...
pub options: HashMap<String, String>,
}
/// A registered function, as returned by `list_functions`.
#[derive(Debug, Clone)]
pub struct FunctionInfo {
pub name: String,
pub language: String,
pub return_type: String,
pub description: String,
}
/// A request to create a materialized view (CREATE MATERIALIZED VIEW).
#[derive(Debug, Clone)]
pub struct CreateMaterializedViewRequest {
/// View name.
pub name: String,
/// The view's SELECT statement, e.g.
/// `SELECT id, embed(body) AS vec FROM articles WHERE id > 1`.
/// Bare columns project through; function-call columns compute via
/// registered UDFs (a RETURNS TABLE function makes a row-expanding
/// chunker view).
pub query: String,
/// Refresh automatically when the source table changes.
pub auto_refresh: bool,
/// Register the definition only; skip the initial population.
pub with_no_data: bool,
/// Optional source column to partition the view's table function on. If the
/// column has an IVF vector index the server partitions by its clusters
/// (image-dedup style); otherwise it groups by distinct value.
pub partition_by: Option<String>,
}
impl CreateMaterializedViewRequest {
pub fn new(name: impl Into<String>, query: impl Into<String>) -> Self {
Self {
name: name.into(),
query: query.into(),
auto_refresh: false,
with_no_data: false,
partition_by: None,
}
}
}
/// A request to refresh a materialized view.
#[derive(Debug, Clone)]
pub struct RefreshMaterializedViewRequest {
/// View name.
pub name: String,
/// Force a full rebuild (recompute and replace every row) instead of the
/// default incremental refresh.
pub full: bool,
/// Pin the refresh to a source-table version; latest when absent.
pub src_version: Option<u64>,
/// Initial worker count.
pub num_workers: Option<u32>,
/// Elastic worker ceiling.
pub max_workers: Option<u32>,
}
/// A request for the derived-compute lineage of a table/view (or one of its
/// columns). The response is server-defined lineage JSON, returned opaque so
/// this client need not model the server's lineage schema.
#[derive(Debug, Clone, Default)]
pub struct TableLineageRequest {
/// Table or view name.
pub name: String,
/// Column for column-level lineage; whole table/view when absent.
pub column: Option<String>,
/// "upstream" | "downstream" | "both" (server default when absent).
pub direction: Option<String>,
/// Column-hops to walk; transitive when absent.
pub depth: Option<u32>,
}
impl RefreshMaterializedViewRequest {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
full: false,
src_version: None,
num_workers: None,
max_workers: None,
}
}
}
/// A registered materialized view definition, as returned by
/// `list_materialized_views`.
#[derive(Debug, Clone)]
pub struct MaterializedViewInfo {
pub name: String,
pub source_table: String,
/// Source columns projected through.
pub projection: Vec<String>,
/// `alias=expression` per UDF-computed column.
pub udf_columns: Vec<String>,
pub filter: Option<String>,
pub auto_refresh: bool,
}
/// A row from `list_jobs`: one inflight server-side job (index build,
/// compaction, column refresh, view refresh, ...).
#[derive(Debug, Clone)]
pub struct JobInfo {
pub table: String,
pub job_id: String,
pub job_type: String,
/// Lifecycle state: "running", "cancelling", or "stale".
pub state: String,
pub column: Option<String>,
pub age_seconds: Option<i64>,
pub command: Option<String>,
pub units_done: Option<i64>,
pub units_total: Option<i64>,
/// Whether the job's final commit has completed (output visible).
pub committed: bool,
pub rows_skipped: u64,
pub error: Option<String>,
}
/// A row from `job_history`: one durable, completed/terminal server-side job
/// record (SHOW JOB HISTORY), read from a table's `_job_history` store. Unlike
/// `JobInfo` (live, inflight jobs) this carries created/updated/completed
/// timestamps and the lifecycle event log.
#[derive(Debug, Clone)]
pub struct JobHistoryInfo {
pub table: String,
pub job_id: String,
pub job_type: String,
pub state: String,
pub column: Option<String>,
pub created_ms: i64,
pub updated_ms: i64,
pub completed_ms: Option<i64>,
pub rows_processed: Option<i64>,
pub rows_skipped: Option<i64>,
pub error: Option<String>,
/// Newline-joined lifecycle event log, oldest first.
pub events: Option<String>,
}
/// A row from `errors`: one per-row UDF failure recorded by `error_policy=skip`
/// (SHOW ERRORS).
#[derive(Debug, Clone)]
pub struct JobErrorInfo {
pub job_id: String,
pub table: String,
pub column: String,
pub error_type: String,
pub error_message: String,
pub fragment_id: Option<i64>,
pub source_row_id: Option<i64>,
pub table_version: Option<i64>,
pub age_seconds: Option<i64>,
}
/// The plan a `REFRESH MATERIALIZED VIEW` would execute, as returned by
/// `explain_refresh_materialized_view` (EXPLAIN REFRESH). No work is run.
#[derive(Debug, Clone)]
pub struct MvRefreshPlan {
pub table_name: String,
/// Whether a refresh would do anything (rebuild or non-empty units).
pub has_work: bool,
pub source_version: u64,
pub last_refreshed_version: Option<u64>,
pub full_refresh: bool,
/// Source changed non-append-only since the last refresh -> rebuild.
pub rebuild: bool,
/// Number of row-range work units the refresh would process.
pub units_total: u64,
}
fn not_supported<T>(what: &str) -> Result<T> {
Err(Error::NotSupported {
message: format!("{} is not supported by this database", what),
})
}
/// The `Database` trait defines the interface for database implementations.
///
/// A database is responsible for managing tables and their metadata.
@@ -444,99 +244,6 @@ pub trait Database:
///
/// See [`CloneTableRequest`] for detailed documentation and examples.
async fn clone_table(&self, request: CloneTableRequest) -> Result<Arc<dyn BaseTable>>;
// -- Derived compute: functions, materialized views, jobs -------------
//
// Server-backed features (LanceDB Enterprise / Cloud). The defaults
// return NotSupported; the remote database overrides them. Local
// single-node implementations are planned.
/// Register a UDF (CREATE FUNCTION).
async fn create_function(&self, _request: CreateFunctionRequest) -> Result<()> {
not_supported("create_function")
}
/// List registered functions (SHOW FUNCTIONS).
async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
not_supported("list_functions")
}
/// Drop a registered function (DROP FUNCTION).
async fn drop_function(&self, _name: &str) -> Result<()> {
not_supported("drop_function")
}
/// Create a materialized view (CREATE MATERIALIZED VIEW). Returns
/// the initial-population job id, absent when `with_no_data`.
async fn create_materialized_view(
&self,
_request: CreateMaterializedViewRequest,
) -> Result<Option<String>> {
not_supported("create_materialized_view")
}
/// Refresh a materialized view; returns the refresh job id.
async fn refresh_materialized_view(
&self,
_request: RefreshMaterializedViewRequest,
) -> Result<String> {
not_supported("refresh_materialized_view")
}
/// Derived-compute lineage of a table/view (or column), as server-defined
/// JSON. Read-only.
async fn table_lineage(&self, _request: TableLineageRequest) -> Result<String> {
not_supported("table_lineage")
}
/// Plan a materialized-view refresh without submitting work
/// (EXPLAIN REFRESH). `full` plans a full rebuild (incremental
/// planning requires stable row IDs on the source).
async fn explain_refresh_materialized_view(
&self,
_name: &str,
_full: bool,
_src_version: Option<u64>,
) -> Result<MvRefreshPlan> {
not_supported("explain_refresh_materialized_view")
}
/// Update a materialized view's options (ALTER MATERIALIZED VIEW).
async fn alter_materialized_view(&self, _name: &str, _auto_refresh: bool) -> Result<()> {
not_supported("alter_materialized_view")
}
/// Drop a materialized view definition (DROP MATERIALIZED VIEW).
async fn drop_materialized_view(&self, _name: &str) -> Result<()> {
not_supported("drop_materialized_view")
}
/// List registered materialized view definitions.
async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
not_supported("list_materialized_views")
}
/// List inflight server-side jobs across the database's tables.
async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
not_supported("list_jobs")
}
/// Cancel an inflight server-side job by id. Returns true if a
/// matching inflight job was found and flagged for cancellation,
/// false if none was inflight (best-effort, like SQL `CANCEL JOB`).
async fn cancel_job(&self, _job_id: &str) -> Result<bool> {
not_supported("cancel_job")
}
/// Point-access for a single job by id -- the `wait()`/status poll path.
/// `table_hint` (the job's table, which `wait()` callers know) enables an
/// O(1) server-side lookup. `None` if the job is unknown or not active.
async fn get_job(&self, _job_id: &str, _table_hint: Option<&str>) -> Result<Option<JobInfo>> {
not_supported("get_job")
}
/// Durable job history (SHOW JOB HISTORY) across the database's tables,
/// optionally narrowed to a single `job_id`.
async fn job_history(&self, _job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
not_supported("job_history")
}
/// Per-row UDF errors (SHOW ERRORS) recorded by `error_policy=skip` across
/// the database's tables, optionally filtered by `job_id` and/or `table`.
async fn errors(
&self,
_job_id: Option<&str>,
_table: Option<&str>,
) -> Result<Vec<JobErrorInfo>> {
not_supported("errors")
}
/// Open a table in the database
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
/// Rename a table in the database

View File

@@ -18,7 +18,6 @@ use lance_table::io::commit::commit_handler_from_url;
use object_store::local::LocalFileSystem;
use snafu::ResultExt;
use crate::blob::{ensure_blob_storage_version, has_blob_columns};
use crate::connection::ConnectRequest;
use crate::database::ReadConsistency;
use crate::database::namespace::LanceNamespaceDatabase;
@@ -839,16 +838,13 @@ impl ListingDatabase {
write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
}
let data_schema = request.data.arrow_schema();
if let Some(enable_stable_row_ids) = stable_row_ids_override
.or(self.new_table_config.enable_stable_row_ids)
.or(has_blob_columns(&data_schema).then_some(true))
// Apply enable_stable_row_ids: table-level override takes precedence over connection config
if let Some(enable_stable_row_ids) =
stable_row_ids_override.or(self.new_table_config.enable_stable_row_ids)
{
write_params.enable_stable_row_ids = enable_stable_row_ids;
}
ensure_blob_storage_version(&data_schema, &mut write_params);
if matches!(&request.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite;
}

View File

@@ -4,7 +4,7 @@
//! Namespace-based database implementation that delegates table management to lance-namespace
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use async_trait::async_trait;
use lance::io::commit::namespace_manifest::LanceNamespaceExternalManifestStore;
@@ -23,16 +23,12 @@ use lance_namespace_impls::ConnectBuilder;
use lance_table::io::commit::CommitHandler;
use lance_table::io::commit::external_manifest::ExternalManifestCommitHandler;
use crate::blob::{ensure_blob_storage_version, has_blob_columns};
use crate::connection::NamespaceClientPushdownOperation;
use crate::database::ReadConsistency;
use crate::database::listing::{
NewTableConfig, OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS, OPT_NEW_TABLE_STORAGE_VERSION,
OPT_NEW_TABLE_V2_MANIFEST_PATHS,
};
use crate::database::read_freshness::{
FreshnessBaselines, ReadFreshnessContextProvider, TableFreshness,
};
use crate::error::{Error, Result};
use crate::table::{NativeTable, map_namespace_lance_error};
use lance::dataset::WriteMode;
@@ -55,10 +51,6 @@ fn is_table_already_exists_namespace_error(err: &lance::Error) -> bool {
false
}
/// Object-id delimiter default (matches `RestNamespaceBuilder`'s); overridable
/// via the `delimiter` property.
const DEFAULT_NAMESPACE_DELIMITER: &str = "$";
/// A database implementation that uses lance-namespace for table management
pub struct LanceNamespaceDatabase {
namespace: Arc<dyn LanceNamespace>,
@@ -78,17 +70,6 @@ pub struct LanceNamespaceDatabase {
ns_properties: HashMap<String, String>,
// Options for tables created by this connection
new_table_config: NewTableConfig,
// Per-table read-freshness baselines, shared with the context provider.
freshness_baselines: FreshnessBaselines,
// Delimiter for building freshness keys; see `table_freshness`.
delimiter: String,
}
fn resolve_delimiter(ns_properties: &HashMap<String, String>) -> String {
ns_properties
.get("delimiter")
.cloned()
.unwrap_or_else(|| DEFAULT_NAMESPACE_DELIMITER.to_string())
}
impl LanceNamespaceDatabase {
@@ -101,9 +82,6 @@ impl LanceNamespaceDatabase {
session: Option<Arc<lance::session::Session>>,
namespace_client_pushdown_operations: HashSet<NamespaceClientPushdownOperation>,
) -> Self {
// Client is pre-built, so we can't install the freshness provider here;
// baselines are still tracked for a uniform bump path.
let delimiter = resolve_delimiter(&namespace_client_properties);
Self {
namespace: namespace_client,
storage_options,
@@ -114,8 +92,6 @@ impl LanceNamespaceDatabase {
ns_impl: namespace_client_impl,
ns_properties: namespace_client_properties,
new_table_config: NewTableConfig::default(),
freshness_baselines: Arc::new(Mutex::new(HashMap::new())),
delimiter,
}
}
@@ -160,19 +136,10 @@ impl LanceNamespaceDatabase {
if let Some(ref sess) = session {
builder = builder.session(sess.clone());
}
// Install the read-freshness provider before building the client.
let freshness_baselines: FreshnessBaselines = Arc::new(Mutex::new(HashMap::new()));
builder = builder.context_provider(Arc::new(ReadFreshnessContextProvider::new(
freshness_baselines.clone(),
read_consistency_interval,
)));
let namespace = builder.connect().await.map_err(|e| Error::InvalidInput {
message: format!("Failed to connect to namespace: {:?}", e),
})?;
let delimiter = resolve_delimiter(&ns_properties);
Ok(Self {
namespace,
storage_options,
@@ -183,20 +150,9 @@ impl LanceNamespaceDatabase {
ns_impl: ns_impl.to_string(),
ns_properties,
new_table_config,
freshness_baselines,
delimiter,
})
}
/// Build a table's freshness handle, keyed to match the `object_id` the
/// namespace client sends on reads (table-id parts joined by the delimiter).
fn table_freshness(&self, namespace_path: &[String], name: &str) -> TableFreshness {
let mut parts = namespace_path.to_vec();
parts.push(name.to_string());
let key = parts.join(&self.delimiter);
TableFreshness::new(self.freshness_baselines.clone(), key)
}
fn extract_storage_overrides(
&self,
request: &DbCreateTableRequest,
@@ -258,16 +214,12 @@ impl LanceNamespaceDatabase {
params.enable_v2_manifest_paths = enable_v2_manifest_paths;
}
let data_schema = request.data.schema();
if let Some(enable_stable_row_ids) = stable_row_ids_override
.or(self.new_table_config.enable_stable_row_ids)
.or(has_blob_columns(data_schema.as_ref()).then_some(true))
if let Some(enable_stable_row_ids) =
stable_row_ids_override.or(self.new_table_config.enable_stable_row_ids)
{
params.enable_stable_row_ids = enable_stable_row_ids;
}
ensure_blob_storage_version(data_schema.as_ref(), params);
Ok(())
}
}
@@ -379,8 +331,7 @@ impl Database for LanceNamespaceDatabase {
self.pushdown_operations.clone(),
self.session.clone(),
)
.await?
.with_freshness(self.table_freshness(&request.namespace_path, &request.name));
.await?;
return Ok(Arc::new(native_table));
}
@@ -511,8 +462,7 @@ impl Database for LanceNamespaceDatabase {
self.pushdown_operations.clone(),
self.session.clone(),
)
.await?
.with_freshness(self.table_freshness(&request.namespace_path, &request.name));
.await?;
Ok(Arc::new(native_table))
}
@@ -528,8 +478,7 @@ impl Database for LanceNamespaceDatabase {
self.pushdown_operations.clone(),
self.session.clone(),
)
.await?
.with_freshness(self.table_freshness(&request.namespace_path, &request.name));
.await?;
Ok(Arc::new(native_table))
}

View File

@@ -1,312 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Read-freshness signaling for the lance-namespace path.
//!
//! Against a server that serves cached table metadata up to some staleness
//! window, a handle that just wrote (or asked for the latest version via
//! `checkout_latest`) can still read a stale snapshot. To prevent that, reads
//! routed through the namespace client carry an `x-lancedb-min-timestamp`
//! header naming the oldest snapshot the caller will accept.
//!
//! This mirrors `remote::table`: a per-table baseline is bumped to "now" on
//! every write and on `checkout_latest()`, and reads send
//! `max(baseline, now - read_consistency_interval)`. Since the namespace client
//! takes no headers directly, a [`DynamicContextProvider`] injects it per request.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
/// Provider context keys prefixed with `headers.` become HTTP headers (prefix
/// stripped), so this emits the `x-lancedb-min-timestamp` header.
const MIN_TIMESTAMP_CONTEXT_KEY: &str = "headers.x-lancedb-min-timestamp";
/// Per-table freshness baselines (keyed by namespace object id), shared between
/// the provider that reads them and the table handles that bump them.
pub type FreshnessBaselines = Arc<Mutex<HashMap<String, SystemTime>>>;
/// `max(baseline, now - interval)`, or `None` when neither constraint applies.
fn compute_min_timestamp(
baseline: Option<SystemTime>,
interval: Option<Duration>,
now: SystemTime,
) -> Option<SystemTime> {
let interval_based = match interval {
None => None,
Some(d) if d.is_zero() => Some(now),
Some(d) => Some(now.checked_sub(d).unwrap_or(now)),
};
match (interval_based, baseline) {
(None, None) => None,
(Some(t), None) | (None, Some(t)) => Some(t),
(Some(a), Some(b)) => Some(a.max(b)),
}
}
/// Advance the baseline to `now`, never backwards, so a concurrent handle's
/// write can't lower a floor another handle already set.
fn next_freshness_baseline(prev: Option<SystemTime>, now: SystemTime) -> SystemTime {
match prev {
Some(p) => p.max(now),
None => now,
}
}
/// A handle's view of the shared baseline map for a single table.
#[derive(Clone, Debug)]
pub struct TableFreshness {
baselines: FreshnessBaselines,
/// Namespace object id for this table (matches the read's `object_id`).
key: String,
}
impl TableFreshness {
pub fn new(baselines: FreshnessBaselines, key: String) -> Self {
Self { baselines, key }
}
pub fn bump(&self) {
let now = SystemTime::now();
let mut baselines = self.baselines.lock().unwrap();
let prev = baselines.get(&self.key).copied();
baselines.insert(self.key.clone(), next_freshness_baseline(prev, now));
}
}
/// Read ops that can be served stale and so carry the freshness floor.
/// `list_table_versions` resolves "latest" for managed-versioning tables, so it
/// is what makes `checkout_latest()` observe a prior write.
fn is_read_operation(operation: &str) -> bool {
matches!(
operation,
"describe_table" | "list_table_versions" | "query_table" | "list_tables"
)
}
/// Injects `x-lancedb-min-timestamp` on namespace reads, per addressed table.
#[derive(Debug)]
pub struct ReadFreshnessContextProvider {
baselines: FreshnessBaselines,
read_consistency_interval: Option<Duration>,
}
impl ReadFreshnessContextProvider {
pub fn new(baselines: FreshnessBaselines, read_consistency_interval: Option<Duration>) -> Self {
Self {
baselines,
read_consistency_interval,
}
}
}
impl DynamicContextProvider for ReadFreshnessContextProvider {
fn provide_context(&self, info: &OperationInfo) -> HashMap<String, String> {
if !is_read_operation(&info.operation) {
return HashMap::new();
}
let baseline = self.baselines.lock().unwrap().get(&info.object_id).copied();
match compute_min_timestamp(baseline, self.read_consistency_interval, SystemTime::now()) {
Some(ts) => {
let dt: chrono::DateTime<chrono::Utc> = ts.into();
HashMap::from([(MIN_TIMESTAMP_CONTEXT_KEY.to_string(), dt.to_rfc3339())])
}
None => HashMap::new(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Allowed slop when comparing a header timestamp against a locally
/// captured wall-clock bound. Tests run fast enough that 1s is plenty.
const TOLERANCE: Duration = Duration::from_secs(1);
fn parse_header_ts(headers: &HashMap<String, String>) -> SystemTime {
let value = headers
.get(MIN_TIMESTAMP_CONTEXT_KEY)
.expect("expected min-timestamp context key");
chrono::DateTime::parse_from_rfc3339(value)
.unwrap()
.with_timezone(&chrono::Utc)
.into()
}
#[test]
fn test_compute_min_timestamp_combines_baseline_and_interval() {
let now = SystemTime::now();
let baseline = now - Duration::from_secs(60);
// No interval, no baseline -> no header.
assert_eq!(compute_min_timestamp(None, None, now), None);
// Baseline only -> baseline.
assert_eq!(
compute_min_timestamp(Some(baseline), None, now),
Some(baseline)
);
// ZERO interval, no baseline -> now (strong consistency).
assert_eq!(
compute_min_timestamp(None, Some(Duration::ZERO), now),
Some(now)
);
// Positive interval, no baseline -> now - interval.
assert_eq!(
compute_min_timestamp(None, Some(Duration::from_secs(10)), now),
Some(now - Duration::from_secs(10))
);
// Both: pick the more-recent (tighter) constraint.
// baseline = now-60, now-interval = now-10. now-10 is newer.
assert_eq!(
compute_min_timestamp(Some(baseline), Some(Duration::from_secs(10)), now),
Some(now - Duration::from_secs(10))
);
// Both, baseline newer: pick baseline.
let recent_baseline = now - Duration::from_secs(5);
assert_eq!(
compute_min_timestamp(Some(recent_baseline), Some(Duration::from_secs(60)), now),
Some(recent_baseline)
);
}
#[test]
fn test_next_freshness_baseline_is_monotonic() {
let now = SystemTime::now();
let earlier = now - Duration::from_secs(30);
let later = now + Duration::from_secs(30);
// No prior baseline -> now.
assert_eq!(next_freshness_baseline(None, now), now);
// Prior baseline older than now -> now.
assert_eq!(next_freshness_baseline(Some(earlier), now), now);
// Prior baseline newer than now -> keep the newer baseline.
assert_eq!(next_freshness_baseline(Some(later), now), later);
}
fn provider_with(
entries: &[(&str, SystemTime)],
interval: Option<Duration>,
) -> ReadFreshnessContextProvider {
let map: HashMap<String, SystemTime> =
entries.iter().map(|(k, v)| (k.to_string(), *v)).collect();
ReadFreshnessContextProvider::new(Arc::new(Mutex::new(map)), interval)
}
#[test]
fn test_provider_emits_header_at_or_after_bumped_baseline() {
// A baseline set "now" with no interval: every read op must carry a
// floor at or after that baseline. `list_table_versions` is the hook
// that makes managed-versioning `checkout_latest()` observe a write.
let baseline = SystemTime::now();
let provider = provider_with(&[("ns$tbl", baseline)], None);
// These ops are keyed by the table id, so they pick up the per-table
// baseline. (`list_tables` is keyed by the namespace, so it is covered
// separately by the interval-floor test.)
for op in ["describe_table", "list_table_versions", "query_table"] {
let ctx = provider.provide_context(&OperationInfo::new(op, "ns$tbl"));
let sent = parse_header_ts(&ctx);
assert!(
sent >= baseline - TOLERANCE && sent <= baseline + TOLERANCE,
"operation {op} should carry a floor at the bumped baseline"
);
}
}
#[test]
fn test_provider_list_tables_uses_interval_floor_not_table_baseline() {
// `list_tables` is addressed by the namespace id, which never matches a
// per-table baseline key, so a bumped table baseline must not leak onto
// it. With no interval it sends nothing; with one it sends now-interval.
let provider = provider_with(&[("ns$tbl", SystemTime::now())], None);
let ctx = provider.provide_context(&OperationInfo::new("list_tables", "ns"));
assert!(
ctx.is_empty(),
"list_tables must not inherit a per-table baseline"
);
let interval = Duration::from_secs(30);
let provider = provider_with(&[("ns$tbl", SystemTime::now())], Some(interval));
let before = SystemTime::now();
let ctx = provider.provide_context(&OperationInfo::new("list_tables", "ns"));
let after = SystemTime::now();
let sent = parse_header_ts(&ctx);
assert!(
sent >= before - interval - TOLERANCE && sent <= after - interval + TOLERANCE,
"list_tables should carry the interval floor"
);
}
#[test]
fn test_provider_no_header_for_empty_baseline_and_no_interval() {
// Manual consistency (no interval) on a table that was never bumped:
// no floor, so the server may serve from cache.
let provider = provider_with(&[], None);
let ctx = provider.provide_context(&OperationInfo::new("describe_table", "ns$tbl"));
assert!(ctx.is_empty());
}
#[test]
fn test_provider_interval_floor_applies_without_baseline() {
// With a consistency interval and no baseline, the floor is now-interval.
let interval = Duration::from_secs(30);
let provider = provider_with(&[], Some(interval));
let before = SystemTime::now();
let ctx = provider.provide_context(&OperationInfo::new("query_table", "ns$tbl"));
let after = SystemTime::now();
let sent = parse_header_ts(&ctx);
assert!(
sent >= before - interval - TOLERANCE && sent <= after - interval + TOLERANCE,
"expected floor at roughly now - interval"
);
}
#[test]
fn test_provider_non_read_ops_emit_nothing() {
// Even with a fresh baseline and a zero interval, a non-read operation
// (which establishes rather than consumes a baseline) sends no header.
let provider = provider_with(&[("ns$tbl", SystemTime::now())], Some(Duration::ZERO));
for op in [
"create_table",
"register_table",
"drop_table",
"rename_table",
// Pinned to an immutable version, so it cannot be served stale.
"describe_table_version",
] {
let ctx = provider.provide_context(&OperationInfo::new(op, "ns$tbl"));
assert!(
ctx.is_empty(),
"operation {op} must not send a freshness header"
);
}
}
#[test]
fn test_provider_uses_per_table_baseline() {
// The floor is looked up by object id, so an unrelated table's baseline
// does not leak onto another table's read.
let baseline = SystemTime::now();
let provider = provider_with(&[("ns$has_baseline", baseline)], None);
// The bumped table gets a header.
let hit =
provider.provide_context(&OperationInfo::new("describe_table", "ns$has_baseline"));
assert!(!hit.is_empty());
// A different table with no baseline (and no interval) gets nothing.
let miss = provider.provide_context(&OperationInfo::new("describe_table", "ns$other"));
assert!(miss.is_empty());
}
}

View File

@@ -13,7 +13,7 @@ use serde_json::{Value, json};
use super::EmbeddingFunction;
use crate::{Error, Result};
use tokio::runtime::{Handle, RuntimeFlavor};
use tokio::runtime::Handle;
use tokio::task::block_in_place;
#[derive(Debug)]
@@ -148,12 +148,6 @@ impl BedrockEmbeddingFunction {
_ => unreachable!(),
};
// Bedrock's SDK is async but this trait method is synchronous, so we
// bridge with `block_in_place` + `block_on`. That requires a
// multi-threaded Tokio runtime; return a typed error instead of
// panicking when no compatible runtime is available.
let handle = current_multi_thread_handle()?;
for text in texts {
let request_body = match self.model {
BedrockEmbeddingModel::TitanEmbedding => {
@@ -169,28 +163,24 @@ impl BedrockEmbeddingFunction {
}
};
// Serialize before entering the blocking section so a serialization
// failure surfaces as a typed error rather than an `unwrap` panic.
let body = serde_json::to_vec(&request_body).map_err(|e| Error::Runtime {
message: format!("Failed to serialize Bedrock request: {e}"),
})?;
let client = self.client.clone();
let model_id = self.model.model_id().to_string();
let request_body = request_body.clone();
let response = block_in_place(|| {
handle.block_on(async move {
let response = block_in_place(move || {
Handle::current().block_on(async move {
client
.invoke_model()
.model_id(model_id)
.body(aws_sdk_bedrockruntime::primitives::Blob::new(body))
.body(aws_sdk_bedrockruntime::primitives::Blob::new(
serde_json::to_vec(&request_body).unwrap(),
))
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Bedrock invoke_model request failed: {e}"),
})
.map_err(Box::new)
})
})?;
})
.unwrap();
let response_json: Value =
serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime {
@@ -198,12 +188,22 @@ impl BedrockEmbeddingFunction {
})?;
let embedding = match self.model {
BedrockEmbeddingModel::TitanEmbedding => {
json_array_to_f32(&response_json["embedding"], "embedding")?
}
BedrockEmbeddingModel::CohereLarge => {
json_array_to_f32(&response_json["embeddings"][0], "embeddings")?
}
BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"]
.as_array()
.ok_or_else(|| Error::Runtime {
message: "Missing embedding in response".to_string(),
})?
.iter()
.map(|v| v.as_f64().unwrap() as f32)
.collect::<Vec<f32>>(),
BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0]
.as_array()
.ok_or_else(|| Error::Runtime {
message: "Missing embeddings in response".to_string(),
})?
.iter()
.map(|v| v.as_f64().unwrap() as f32)
.collect::<Vec<f32>>(),
};
builder.append_slice(&embedding);
@@ -212,86 +212,3 @@ impl BedrockEmbeddingFunction {
Ok(builder.finish())
}
}
/// Returns a handle to the current multi-threaded Tokio runtime, or a typed
/// [`Error::Runtime`] when called outside a runtime or on the current-thread
/// runtime. This keeps the synchronous-over-async bridge in
/// [`BedrockEmbeddingFunction::compute_inner`] from panicking on runtime
/// configurations that cannot support `block_in_place`.
fn current_multi_thread_handle() -> Result<Handle> {
let handle = Handle::try_current().map_err(|e| Error::Runtime {
message: format!("Bedrock embedding must be called from within a Tokio runtime: {e}"),
})?;
if handle.runtime_flavor() == RuntimeFlavor::CurrentThread {
return Err(Error::Runtime {
message: "Bedrock embedding requires a multi-threaded Tokio runtime; the \
current-thread runtime cannot use `block_in_place`"
.to_string(),
});
}
Ok(handle)
}
/// Converts a JSON value expected to be an array of numbers into `Vec<f32>`.
///
/// Returns a typed [`Error::Runtime`] (rather than panicking) when the value is
/// not an array or contains a non-numeric element, so malformed provider
/// responses degrade gracefully.
fn json_array_to_f32(value: &Value, field: &str) -> Result<Vec<f32>> {
let arr = value.as_array().ok_or_else(|| Error::Runtime {
message: format!("Missing or non-array '{field}' field in Bedrock response"),
})?;
arr.iter()
.map(|v| {
v.as_f64().map(|f| f as f32).ok_or_else(|| Error::Runtime {
message: format!("Non-numeric value in Bedrock '{field}' embedding: {v}"),
})
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn json_array_to_f32_parses_numbers() {
let v = json!([1.0, 2, -3.5]);
let out = json_array_to_f32(&v, "embedding").unwrap();
assert_eq!(out, vec![1.0_f32, 2.0, -3.5]);
}
#[test]
fn json_array_to_f32_rejects_non_array() {
// Missing field indexes to `Value::Null`; a malformed payload should be
// a typed error, not a panic.
let v = json!({"unexpected": "shape"});
let err = json_array_to_f32(&v["embedding"], "embedding").unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[test]
fn json_array_to_f32_rejects_non_numeric_element() {
let v = json!([1.0, "not-a-number", 3.0]);
let err = json_array_to_f32(&v, "embedding").unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[test]
fn handle_errors_without_runtime() {
// No Tokio runtime in scope -> typed error instead of a panic.
let err = current_multi_thread_handle().unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[tokio::test(flavor = "current_thread")]
async fn handle_errors_on_current_thread_runtime() {
let err = current_multi_thread_handle().unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[tokio::test(flavor = "multi_thread")]
async fn handle_ok_on_multi_thread_runtime() {
current_multi_thread_handle().expect("multi-threaded runtime should be accepted");
}
}

View File

@@ -163,7 +163,6 @@
//! ```
pub mod arrow;
pub mod blob;
pub mod connection;
pub mod data;
pub mod database;
@@ -189,7 +188,6 @@ use std::{fmt::Display, str::FromStr};
use serde::{Deserialize, Serialize};
pub use blob::{blob, is_blob};
pub use connection::{ConnectNamespaceBuilder, Connection};
pub use error::{Error, Result};
use lance_index::vector::ApproxMode as LanceApproxMode;

View File

@@ -8,7 +8,6 @@
pub(crate) mod client;
pub(crate) mod db;
pub mod oauth;
mod retry;
pub(crate) mod table;
pub(crate) mod util;
@@ -21,4 +20,3 @@ const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};

View File

@@ -459,14 +459,12 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
config: &ClientConfig,
) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
if !api_key.is_empty() {
headers.insert(
HeaderName::from_static("x-api-key"),
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(),
})?,
);
}
headers.insert(
HeaderName::from_static("x-api-key"),
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(),
})?,
);
if region == "local" {
let host = format!("{}.local.api.lancedb.com", db_name);
headers.insert(
@@ -1007,33 +1005,6 @@ mod tests {
assert!(!config_tls.assert_hostname);
}
#[test]
fn test_default_headers_skip_empty_api_key() {
let headers = RestfulLanceDbClient::<Sender>::default_headers(
"",
"us-east-1",
"db-name",
false,
&RemoteOptions::default(),
None,
&ClientConfig::default(),
)
.unwrap();
assert!(!headers.contains_key("x-api-key"));
let headers = RestfulLanceDbClient::<Sender>::default_headers(
"api-key",
"us-east-1",
"db-name",
false,
&RemoteOptions::default(),
None,
&ClientConfig::default(),
)
.unwrap();
assert_eq!(headers.get("x-api-key").unwrap(), "api-key");
}
// Test implementation of HeaderProvider
#[derive(Debug, Clone)]
struct TestHeaderProvider {

View File

@@ -7,7 +7,6 @@ use std::sync::Arc;
use async_trait::async_trait;
use http::StatusCode;
use lance_io::object_store::StorageOptions;
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
use moka::future::Cache;
use reqwest::header::CONTENT_TYPE;
@@ -19,264 +18,18 @@ use lance_namespace::models::{
use crate::Error;
use crate::database::{
CloneTableRequest, CreateFunctionRequest, CreateMaterializedViewRequest, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, FunctionInfo, JobErrorInfo, JobHistoryInfo,
JobInfo, MaterializedViewInfo, MvRefreshPlan, OpenTableRequest, ReadConsistency,
RefreshMaterializedViewRequest, TableLineageRequest, TableNamesRequest,
CloneTableRequest, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
OpenTableRequest, ReadConsistency, TableNamesRequest,
};
use crate::error::Result;
use crate::remote::util::stream_as_body;
use crate::table::BaseTable;
use super::ARROW_STREAM_CONTENT_TYPE;
use super::client::{
ClientConfig, HeaderProvider, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender,
};
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
use super::util::parse_server_version;
// Wire types for the derived-compute routes (functions, materialized
// views, jobs). Field shapes mirror the server's REST contract.
#[derive(serde::Serialize)]
struct RemoteCreateFunctionRequest {
language: String,
return_type: String,
body: String,
options: std::collections::HashMap<String, String>,
}
#[derive(serde::Deserialize)]
struct RemoteFunctionEntry {
name: String,
language: String,
return_type: String,
#[serde(default)]
description: String,
}
#[derive(serde::Deserialize)]
struct RemoteListFunctionsResponse {
functions: Vec<RemoteFunctionEntry>,
}
#[derive(serde::Serialize)]
struct RemoteCreateMaterializedViewRequest {
query: String,
auto_refresh: bool,
with_no_data: bool,
#[serde(skip_serializing_if = "Option::is_none")]
partition_by: Option<String>,
}
#[derive(serde::Deserialize)]
struct RemoteCreateMaterializedViewResponse {
#[serde(default)]
job_id: Option<String>,
}
#[derive(serde::Serialize)]
struct RemoteRefreshMaterializedViewRequest {
#[serde(skip_serializing_if = "std::ops::Not::not")]
full: bool,
#[serde(skip_serializing_if = "Option::is_none")]
src_version: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
num_workers: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_workers: Option<u32>,
}
#[derive(serde::Deserialize)]
struct RemoteRefreshMaterializedViewResponse {
job_id: String,
}
#[derive(serde::Serialize)]
struct RemoteExplainRefreshRequest {
#[serde(skip_serializing_if = "Option::is_none")]
full: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
src_version: Option<u64>,
}
#[derive(serde::Deserialize)]
struct RemoteExplainRefreshResponse {
table_name: String,
has_work: bool,
source_version: u64,
last_refreshed_version: Option<u64>,
full_refresh: bool,
rebuild: bool,
units_total: u64,
}
#[derive(serde::Serialize)]
struct RemoteAlterMaterializedViewRequest {
auto_refresh: bool,
}
#[derive(serde::Deserialize)]
struct RemoteMaterializedViewEntry {
name: String,
source_table: String,
#[serde(default)]
projection: Vec<String>,
#[serde(default)]
udf_columns: Vec<String>,
#[serde(default)]
filter: Option<String>,
#[serde(default)]
auto_refresh: bool,
}
#[derive(serde::Deserialize)]
struct RemoteListMaterializedViewsResponse {
views: Vec<RemoteMaterializedViewEntry>,
}
#[derive(serde::Deserialize)]
struct RemoteJobEntry {
table: String,
job_id: String,
job_type: String,
state: String,
#[serde(default)]
column: Option<String>,
#[serde(default)]
age_seconds: Option<i64>,
#[serde(default)]
command: Option<String>,
#[serde(default)]
units_done: Option<i64>,
#[serde(default)]
units_total: Option<i64>,
#[serde(default)]
committed: bool,
#[serde(default)]
rows_skipped: u64,
#[serde(default)]
error: Option<String>,
}
#[derive(serde::Deserialize)]
struct RemoteListJobsResponse {
jobs: Vec<RemoteJobEntry>,
}
#[derive(serde::Deserialize)]
struct RemoteGetJobResponse {
#[serde(default)]
job: Option<RemoteJobEntry>,
}
#[derive(serde::Deserialize)]
struct RemoteCancelJobResponse {
cancelled: bool,
}
impl From<RemoteJobEntry> for JobInfo {
fn from(j: RemoteJobEntry) -> Self {
JobInfo {
table: j.table,
job_id: j.job_id,
job_type: j.job_type,
state: j.state,
column: j.column,
age_seconds: j.age_seconds,
command: j.command,
units_done: j.units_done,
units_total: j.units_total,
committed: j.committed,
rows_skipped: j.rows_skipped,
error: j.error,
}
}
}
#[derive(serde::Deserialize)]
struct RemoteJobHistoryEntry {
table: String,
job_id: String,
job_type: String,
state: String,
#[serde(default)]
column: Option<String>,
created_ms: i64,
updated_ms: i64,
#[serde(default)]
completed_ms: Option<i64>,
#[serde(default)]
rows_processed: Option<i64>,
#[serde(default)]
rows_skipped: Option<i64>,
#[serde(default)]
error: Option<String>,
#[serde(default)]
events: Option<String>,
}
#[derive(serde::Deserialize)]
struct RemoteJobHistoryResponse {
jobs: Vec<RemoteJobHistoryEntry>,
}
impl From<RemoteJobHistoryEntry> for JobHistoryInfo {
fn from(j: RemoteJobHistoryEntry) -> Self {
JobHistoryInfo {
table: j.table,
job_id: j.job_id,
job_type: j.job_type,
state: j.state,
column: j.column,
created_ms: j.created_ms,
updated_ms: j.updated_ms,
completed_ms: j.completed_ms,
rows_processed: j.rows_processed,
rows_skipped: j.rows_skipped,
error: j.error,
events: j.events,
}
}
}
#[derive(serde::Deserialize)]
struct RemoteErrorEntry {
job_id: String,
table: String,
column: String,
error_type: String,
error_message: String,
#[serde(default)]
fragment_id: Option<i64>,
#[serde(default)]
source_row_id: Option<i64>,
#[serde(default)]
table_version: Option<i64>,
#[serde(default)]
age_seconds: Option<i64>,
}
#[derive(serde::Deserialize)]
struct RemoteErrorsResponse {
errors: Vec<RemoteErrorEntry>,
}
impl From<RemoteErrorEntry> for JobErrorInfo {
fn from(e: RemoteErrorEntry) -> Self {
JobErrorInfo {
job_id: e.job_id,
table: e.table,
column: e.column,
error_type: e.error_type,
error_message: e.error_message,
fragment_id: e.fragment_id,
source_row_id: e.source_row_id,
table_version: e.table_version,
age_seconds: e.age_seconds,
}
}
}
// Request structure for the remote clone table API
#[derive(serde::Serialize)]
struct RemoteCloneTableRequest {
@@ -441,66 +194,10 @@ pub struct RemoteDatabase<S: HttpSend = Sender> {
uri: String,
/// Headers to pass to the namespace client for authentication
namespace_headers: HashMap<String, String>,
namespace_context_provider: Option<Arc<dyn DynamicContextProvider>>,
/// TLS configuration for mTLS support
tls_config: Option<super::client::TlsConfig>,
}
#[derive(Clone)]
struct NamespaceHeaderProviderContext {
header_provider: Arc<dyn HeaderProvider>,
}
impl std::fmt::Debug for NamespaceHeaderProviderContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NamespaceHeaderProviderContext")
.field("header_provider", &"Some(...)")
.finish()
}
}
impl DynamicContextProvider for NamespaceHeaderProviderContext {
fn provide_context(&self, _info: &OperationInfo) -> HashMap<String, String> {
let header_provider = Arc::clone(&self.header_provider);
let handle = match std::thread::Builder::new()
.name("lancedb-namespace-headers".to_string())
.spawn(move || {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| Error::Runtime {
message: format!(
"Failed to create runtime for namespace header provider: {e}"
),
})?
.block_on(header_provider.get_headers())
}) {
Ok(handle) => handle,
Err(err) => {
log::warn!("Failed to spawn dynamic namespace header provider thread: {err}");
return HashMap::new();
}
};
let headers = handle.join();
match headers {
Ok(Ok(headers)) => headers
.into_iter()
.map(|(key, value)| (format!("headers.{key}"), value))
.collect(),
Ok(Err(err)) => {
log::warn!("Failed to get dynamic namespace headers: {err}");
HashMap::new()
}
Err(_) => {
log::warn!("Dynamic namespace header provider panicked");
HashMap::new()
}
}
}
}
impl RemoteDatabase {
pub fn try_new(
uri: &str,
@@ -531,16 +228,6 @@ impl RemoteDatabase {
})
.collect();
let namespace_context_provider =
client_config
.header_provider
.as_ref()
.map(|header_provider| {
Arc::new(NamespaceHeaderProviderContext {
header_provider: Arc::clone(header_provider),
}) as Arc<dyn DynamicContextProvider>
});
let client = RestfulLanceDbClient::try_new(
&parsed,
region,
@@ -560,7 +247,6 @@ impl RemoteDatabase {
table_cache,
uri: uri.to_owned(),
namespace_headers,
namespace_context_provider,
tls_config: client_config.tls_config,
})
}
@@ -585,7 +271,6 @@ mod test_utils {
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
namespace_headers: HashMap::new(),
namespace_context_provider: None,
tls_config: None,
}
}
@@ -596,18 +281,11 @@ mod test_utils {
T: Into<reqwest::Body>,
{
let client = client_with_handler_and_config(handler, config.clone());
let namespace_context_provider =
config.header_provider.as_ref().map(|header_provider| {
Arc::new(NamespaceHeaderProviderContext {
header_provider: Arc::clone(header_provider),
}) as Arc<dyn DynamicContextProvider>
});
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
namespace_headers: config.extra_headers.clone(),
namespace_context_provider,
tls_config: config.tls_config.clone(),
}
}
@@ -885,228 +563,6 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
Ok(table)
}
async fn create_function(&self, request: CreateFunctionRequest) -> Result<()> {
let body = RemoteCreateFunctionRequest {
language: request.language,
return_type: request.return_type,
body: request.body,
options: request.options,
};
let req = self
.client
.post(&format!("/v1/function/{}/create", request.name))
.json(&body);
let (request_id, rsp) = self.client.send(req).await?;
self.client.check_response(&request_id, rsp).await?;
Ok(())
}
async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
let req = self.client.get("/v1/function/list");
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteListFunctionsResponse = rsp.json().await.err_to_http(request_id)?;
Ok(body
.functions
.into_iter()
.map(|f| FunctionInfo {
name: f.name,
language: f.language,
return_type: f.return_type,
description: f.description,
})
.collect())
}
async fn drop_function(&self, name: &str) -> Result<()> {
let req = self.client.post(&format!("/v1/function/{}/drop", name));
let (request_id, rsp) = self.client.send(req).await?;
self.client.check_response(&request_id, rsp).await?;
Ok(())
}
async fn create_materialized_view(
&self,
request: CreateMaterializedViewRequest,
) -> Result<Option<String>> {
let body = RemoteCreateMaterializedViewRequest {
query: request.query,
auto_refresh: request.auto_refresh,
with_no_data: request.with_no_data,
partition_by: request.partition_by,
};
let req = self
.client
.post(&format!("/v1/materialized_view/{}/create", request.name))
.json(&body);
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteCreateMaterializedViewResponse =
rsp.json().await.err_to_http(request_id)?;
Ok(body.job_id)
}
async fn refresh_materialized_view(
&self,
request: RefreshMaterializedViewRequest,
) -> Result<String> {
let body = RemoteRefreshMaterializedViewRequest {
full: request.full,
src_version: request.src_version,
num_workers: request.num_workers,
max_workers: request.max_workers,
};
let req = self
.client
.post(&format!("/v1/materialized_view/{}/refresh", request.name))
.json(&body);
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteRefreshMaterializedViewResponse =
rsp.json().await.err_to_http(request_id)?;
Ok(body.job_id)
}
async fn table_lineage(&self, request: TableLineageRequest) -> Result<String> {
let mut req = self
.client
.get(&format!("/v1/table/{}/lineage", request.name));
if let Some(column) = &request.column {
req = req.query(&[("column", column)]);
}
if let Some(direction) = &request.direction {
req = req.query(&[("direction", direction)]);
}
if let Some(depth) = request.depth {
req = req.query(&[("depth", depth.to_string())]);
}
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
// Server-defined lineage JSON, returned opaque (the client does not
// model the lineage schema; the Python layer deserializes it).
rsp.text().await.err_to_http(request_id)
}
async fn explain_refresh_materialized_view(
&self,
name: &str,
full: bool,
src_version: Option<u64>,
) -> Result<MvRefreshPlan> {
let body = RemoteExplainRefreshRequest {
full: Some(full),
src_version,
};
let req = self
.client
.post(&format!("/v1/materialized_view/{}/explain_refresh", name))
.json(&body);
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteExplainRefreshResponse = rsp.json().await.err_to_http(request_id)?;
Ok(MvRefreshPlan {
table_name: body.table_name,
has_work: body.has_work,
source_version: body.source_version,
last_refreshed_version: body.last_refreshed_version,
full_refresh: body.full_refresh,
rebuild: body.rebuild,
units_total: body.units_total,
})
}
async fn alter_materialized_view(&self, name: &str, auto_refresh: bool) -> Result<()> {
let req = self
.client
.post(&format!("/v1/materialized_view/{}/alter", name))
.json(&RemoteAlterMaterializedViewRequest { auto_refresh });
let (request_id, rsp) = self.client.send(req).await?;
self.client.check_response(&request_id, rsp).await?;
Ok(())
}
async fn drop_materialized_view(&self, name: &str) -> Result<()> {
let req = self
.client
.post(&format!("/v1/materialized_view/{}/drop", name));
let (request_id, rsp) = self.client.send(req).await?;
self.client.check_response(&request_id, rsp).await?;
Ok(())
}
async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
let req = self.client.get("/v1/materialized_view/list");
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteListMaterializedViewsResponse = rsp.json().await.err_to_http(request_id)?;
Ok(body
.views
.into_iter()
.map(|v| MaterializedViewInfo {
name: v.name,
source_table: v.source_table,
projection: v.projection,
udf_columns: v.udf_columns,
filter: v.filter,
auto_refresh: v.auto_refresh,
})
.collect())
}
async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
let req = self.client.get("/v1/job/list");
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteListJobsResponse = rsp.json().await.err_to_http(request_id)?;
Ok(body.jobs.into_iter().map(JobInfo::from).collect())
}
async fn get_job(&self, job_id: &str, table: Option<&str>) -> Result<Option<JobInfo>> {
// Point-access poll path: GET /v1/job/{id}, with the table as the O(1)
// hint when known. `query` handles URL-encoding the table name.
let mut req = self.client.get(&format!("/v1/job/{job_id}"));
if let Some(t) = table {
req = req.query(&[("table", t)]);
}
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteGetJobResponse = rsp.json().await.err_to_http(request_id)?;
Ok(body.job.map(JobInfo::from))
}
async fn cancel_job(&self, job_id: &str) -> Result<bool> {
let req = self.client.post(&format!("/v1/job/{}/cancel", job_id));
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteCancelJobResponse = rsp.json().await.err_to_http(request_id)?;
Ok(body.cancelled)
}
async fn job_history(&self, job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
let mut req = self.client.get("/v1/job/history");
if let Some(j) = job_id {
req = req.query(&[("job", j)]);
}
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteJobHistoryResponse = rsp.json().await.err_to_http(request_id)?;
Ok(body.jobs.into_iter().map(JobHistoryInfo::from).collect())
}
async fn errors(&self, job_id: Option<&str>, table: Option<&str>) -> Result<Vec<JobErrorInfo>> {
let mut req = self.client.get("/v1/job/errors");
if let Some(j) = job_id {
req = req.query(&[("job", j)]);
}
if let Some(t) = table {
req = req.query(&[("table", t)]);
}
let (request_id, rsp) = self.client.send(req).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let body: RemoteErrorsResponse = rsp.json().await.err_to_http(request_id)?;
Ok(body.errors.into_iter().map(JobErrorInfo::from).collect())
}
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
let identifier = build_table_identifier(
&request.name,
@@ -1303,12 +759,9 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
// Create a RestNamespace pointing to the same remote host with the same authentication headers
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
.delimiter(&self.client.id_delimiter)
// TODO: support header provider
.headers(self.namespace_headers.clone());
if let Some(context_provider) = &self.namespace_context_provider {
builder = builder.context_provider(Arc::clone(context_provider));
}
// Apply mTLS configuration if present
if let Some(tls_config) = &self.tls_config {
if let Some(cert_file) = &tls_config.cert_file {
@@ -1328,14 +781,6 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
}
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
if self.namespace_context_provider.is_some() {
return Err(Error::NotSupported {
message:
"Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead"
.to_string(),
});
}
let mut properties = HashMap::new();
properties.insert("uri".to_string(), self.client.host().to_string());
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
@@ -1387,13 +832,12 @@ impl From<StorageOptions> for RemoteOptions {
#[cfg(test)]
mod tests {
use super::{NamespaceHeaderProviderContext, build_cache_key};
use super::build_cache_key;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
use crate::connection::ConnectBuilder;
use crate::{
@@ -2046,223 +1490,6 @@ mod tests {
}
}
#[tokio::test]
async fn test_derived_compute_routes() {
// create_function
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/function/embed/create");
let body: serde_json::Value =
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
assert_eq!(body["language"], "python");
assert_eq!(body["return_type"], "FLOAT[4]");
assert_eq!(body["body"], "def embed(x): ...");
assert_eq!(body["options"]["pip"], "torch");
http::Response::builder()
.status(200)
.body(r#"{"name":"embed","status":"OK"}"#)
.unwrap()
});
conn.create_function(crate::database::CreateFunctionRequest {
name: "embed".into(),
language: "python".into(),
return_type: "FLOAT[4]".into(),
body: "def embed(x): ...".into(),
options: [("pip".to_string(), "torch".to_string())].into(),
})
.await
.unwrap();
// list_functions
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/function/list");
http::Response::builder()
.status(200)
.body(
r#"{"functions":[{"name":"embed","language":"python","return_type":"Float32","description":""}]}"#,
)
.unwrap()
});
let functions = conn.list_functions().await.unwrap();
assert_eq!(functions.len(), 1);
assert_eq!(functions[0].name, "embed");
// drop_function
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/function/embed/drop");
http::Response::builder()
.status(200)
.body(r#"{"name":"embed","status":"OK"}"#)
.unwrap()
});
conn.drop_function("embed").await.unwrap();
// create_materialized_view
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/create");
let body: serde_json::Value =
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
assert_eq!(body["query"], "SELECT id, embed(body) AS vec FROM docs");
assert_eq!(body["auto_refresh"], true);
assert_eq!(body["with_no_data"], false);
http::Response::builder()
.status(200)
.body(r#"{"name":"mv1","job_id":"j-1"}"#)
.unwrap()
});
let mut request = crate::database::CreateMaterializedViewRequest::new(
"mv1",
"SELECT id, embed(body) AS vec FROM docs",
);
request.auto_refresh = true;
let job_id = conn.create_materialized_view(request).await.unwrap();
assert_eq!(job_id.as_deref(), Some("j-1"));
// refresh_materialized_view
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/refresh");
let body: serde_json::Value =
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
assert_eq!(body["num_workers"], 2);
assert!(body.get("src_version").is_none());
http::Response::builder()
.status(202)
.body(r#"{"job_id":"j-2"}"#)
.unwrap()
});
let mut request = crate::database::RefreshMaterializedViewRequest::new("mv1");
request.num_workers = Some(2);
let job_id = conn.refresh_materialized_view(request).await.unwrap();
assert_eq!(job_id, "j-2");
// alter_materialized_view
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/alter");
let body: serde_json::Value =
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
assert_eq!(body["auto_refresh"], false);
http::Response::builder()
.status(200)
.body(r#"{"name":"mv1","status":"OK"}"#)
.unwrap()
});
conn.alter_materialized_view("mv1", false).await.unwrap();
// drop_materialized_view
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/drop");
http::Response::builder()
.status(200)
.body(r#"{"name":"mv1","status":"OK"}"#)
.unwrap()
});
conn.drop_materialized_view("mv1").await.unwrap();
// list_materialized_views
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/materialized_view/list");
http::Response::builder()
.status(200)
.body(
r#"{"views":[{"name":"mv1","source_table":"docs","projection":["id"],"udf_columns":["vec=embed(body)"],"filter":null,"auto_refresh":true}]}"#,
)
.unwrap()
});
let views = conn.list_materialized_views().await.unwrap();
assert_eq!(views.len(), 1);
assert_eq!(views[0].source_table, "docs");
assert!(views[0].auto_refresh);
// list_jobs
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/job/list");
http::Response::builder()
.status(200)
.body(
r#"{"jobs":[{"table":"docs","job_id":"j-3","job_type":"udf_virtual_column_backfill","state":"running","column":"vec","age_seconds":4,"command":null,"units_done":1,"units_total":2,"committed":false,"rows_skipped":0,"error":null}]}"#,
)
.unwrap()
});
let jobs = conn.list_jobs().await.unwrap();
assert_eq!(jobs.len(), 1);
assert_eq!(jobs[0].state, "running");
assert_eq!(jobs[0].units_total, Some(2));
// cancel_job
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/job/j-3/cancel");
http::Response::builder()
.status(200)
.body(r#"{"cancelled":true}"#)
.unwrap()
});
assert!(conn.cancel_job("j-3").await.unwrap());
// cancel_job: no such inflight job -> false, not an error
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.url().path(), "/v1/job/gone/cancel");
http::Response::builder()
.status(200)
.body(r#"{"cancelled":false}"#)
.unwrap()
});
assert!(!conn.cancel_job("gone").await.unwrap());
// job_history: GET /v1/job/history, no filter
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/job/history");
assert!(request.url().query().is_none());
http::Response::builder()
.status(200)
.body(
r#"{"jobs":[{"table":"docs","job_id":"j-1","job_type":"udf_virtual_column_backfill","state":"done","column":"vec","created_ms":1000,"updated_ms":2000,"completed_ms":2000,"rows_processed":42,"rows_skipped":3,"error":null,"events":"created\ndone"}]}"#,
)
.unwrap()
});
let hist = conn.job_history(None).await.unwrap();
assert_eq!(hist.len(), 1);
assert_eq!(hist[0].state, "done");
assert_eq!(hist[0].rows_processed, Some(42));
assert_eq!(hist[0].events.as_deref(), Some("created\ndone"));
// job_history: ?job= narrows to one job
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.url().path(), "/v1/job/history");
assert_eq!(request.url().query(), Some("job=j-1"));
http::Response::builder()
.status(200)
.body(r#"{"jobs":[]}"#)
.unwrap()
});
assert!(conn.job_history(Some("j-1")).await.unwrap().is_empty());
// errors: GET /v1/job/errors with job + table filters
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/job/errors");
assert_eq!(request.url().query(), Some("job=j-1&table=docs"));
http::Response::builder()
.status(200)
.body(
r#"{"errors":[{"job_id":"j-1","table":"docs","column":"vec","error_type":"ValueError","error_message":"boom","fragment_id":0,"source_row_id":42,"table_version":7,"age_seconds":5}]}"#,
)
.unwrap()
});
let errs = conn.errors(Some("j-1"), Some("docs")).await.unwrap();
assert_eq!(errs.len(), 1);
assert_eq!(errs[0].error_type, "ValueError");
assert_eq!(errs[0].source_row_id, Some(42));
}
#[tokio::test]
async fn test_clone_table() {
let conn = Connection::new_with_handler(|request| {
@@ -2475,75 +1702,6 @@ mod tests {
assert!(namespace_client.is_ok());
}
#[test]
fn test_namespace_header_provider_context_maps_headers() {
#[derive(Debug)]
struct TestHeaderProvider;
#[async_trait::async_trait]
impl HeaderProvider for TestHeaderProvider {
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
Ok(HashMap::from([(
"authorization".to_string(),
"Bearer token".to_string(),
)]))
}
}
let context_provider = NamespaceHeaderProviderContext {
header_provider: Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>,
};
let context =
context_provider.provide_context(&OperationInfo::new("list_tables", "namespace"));
assert_eq!(
context.get("headers.authorization"),
Some(&"Bearer token".to_string())
);
}
#[tokio::test]
async fn test_namespace_client_supports_dynamic_headers() {
#[derive(Debug)]
struct TestHeaderProvider;
#[async_trait::async_trait]
impl HeaderProvider for TestHeaderProvider {
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
Ok(HashMap::from([(
"authorization".to_string(),
"Bearer token".to_string(),
)]))
}
}
let client_config = ClientConfig {
header_provider: Some(Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>),
..Default::default()
};
let conn = Connection::new_with_handler_and_config(
|_| {
http::Response::builder()
.status(200)
.body(r#"{"tables": []}"#)
.unwrap()
},
client_config,
);
let namespace_client = conn.namespace_client().await;
assert!(namespace_client.is_ok());
match conn.namespace_client_config().await {
Err(Error::NotSupported { message })
if message.contains("dynamic headers are configured") => {}
Err(err) => panic!("expected NotSupported, got {err:?}"),
Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"),
}
}
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
mod rest_adapter_integration {
use super::*;

View File

@@ -1,907 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use log::debug;
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::error::{Error, Result};
use crate::remote::client::HeaderProvider;
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
const DEFAULT_TOKEN_TTL_SECS: u64 = 3600;
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
/// OAuth authentication flow configuration.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
/// Client Credentials grant (service-to-service / M2M).
/// Requires `client_secret` in [`OAuthConfig`].
ClientCredentials,
/// Azure Managed Identity via IMDS.
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
/// IMDS requests bypass proxy settings because the endpoint is link-local.
AzureManagedIdentity {
/// Client ID for user-assigned managed identity.
/// Omit for system-assigned managed identity.
client_id: Option<String>,
},
}
/// OAuth configuration for LanceDB authentication.
///
/// All token acquisition and refresh is handled in the Rust layer.
/// Python and TypeScript bindings expose this as a plain config object.
#[derive(Clone)]
pub struct OAuthConfig {
/// OIDC issuer URL or OAuth authority URL.
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
pub issuer_url: String,
/// Application / Client ID.
pub client_id: String,
/// Client secret (required for `ClientCredentials`, optional for others).
pub client_secret: Option<String>,
/// OAuth scopes to request.
/// For Azure managed identity, exactly one scope or resource is required.
/// For example: `["api://{app_id}/.default"]`
pub scopes: Vec<String>,
/// Authentication flow to use.
pub flow: OAuthFlow,
/// Seconds before token expiry to trigger proactive refresh (default: 300).
/// Keep this well below the token TTL; if it is greater than or equal to
/// the TTL, each request refreshes the token.
pub refresh_buffer_secs: Option<u64>,
}
impl std::fmt::Debug for OAuthConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthConfig")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field(
"client_secret",
&self.client_secret.as_deref().map(|_| "<redacted>"),
)
.field("scopes", &self.scopes)
.field("flow", &self.flow)
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
.finish()
}
}
// -- OIDC Discovery --
#[derive(Clone, Debug, Deserialize)]
struct OidcDiscovery {
token_endpoint: String,
}
// -- Token Response --
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
/// Token lifetime in seconds.
/// Some providers (Azure IMDS) return this as a string, so we accept both.
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
expires_in: Option<u64>,
#[serde(default)]
#[allow(dead_code)]
token_type: Option<String>,
}
impl std::fmt::Debug for TokenResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokenResponse")
.field("access_token", &"<redacted>")
.field("expires_in", &self.expires_in)
.field("token_type", &self.token_type)
.finish()
}
}
fn deserialize_optional_u64_or_string<'de, D>(
deserializer: D,
) -> std::result::Result<Option<u64>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;
struct U64OrString;
impl<'de> de::Visitor<'de> for U64OrString {
type Value = Option<u64>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("an integer, an integer-valued float, a numeric string, or null")
}
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
Ok(Some(v))
}
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
if v < 0 {
return Err(E::custom(format!("invalid expires_in value: {v}")));
}
Ok(Some(v as u64))
}
fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
if !v.is_finite() || v < 0.0 || v.fract() != 0.0 || v > u64::MAX as f64 {
return Err(E::custom(format!("invalid expires_in value: {v}")));
}
Ok(Some(v as u64))
}
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
v.parse::<u64>().map(Some).map_err(de::Error::custom)
}
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
Ok(None)
}
}
deserializer.deserialize_any(U64OrString)
}
// -- Internal Token State --
struct TokenState {
access_token: Option<String>,
expires_at: Option<Instant>,
}
impl TokenState {
fn new() -> Self {
Self {
access_token: None,
expires_at: None,
}
}
fn is_expired(&self, buffer: Duration) -> bool {
match (self.access_token.as_ref(), self.expires_at) {
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
(None, _) => true,
(Some(_), None) => true,
}
}
fn update(&mut self, resp: &TokenResponse) {
self.access_token = Some(resp.access_token.clone());
let expires_in = resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS);
self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in));
}
}
#[async_trait]
trait TokenSource: Send + Sync + std::fmt::Debug {
async fn fetch_token(&self) -> Result<TokenResponse>;
}
struct ClientCredentialsSource {
issuer_url: String,
client_id: String,
client_secret: String,
scopes: Vec<String>,
http_client: Client,
discovery: RwLock<Option<OidcDiscovery>>,
}
impl std::fmt::Debug for ClientCredentialsSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ClientCredentialsSource")
.field("issuer_url", &self.issuer_url)
.field("client_id", &self.client_id)
.field("client_secret", &"<redacted>")
.field("scopes", &self.scopes)
.finish()
}
}
impl ClientCredentialsSource {
fn new(
issuer_url: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self> {
let client_secret = client_secret.ok_or(Error::InvalidInput {
message: "client_secret is required for ClientCredentials flow".to_string(),
})?;
Self::validate_issuer_transport(&issuer_url)?;
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for OAuth: {e}"),
})?;
Ok(Self {
issuer_url,
client_id,
client_secret,
scopes,
http_client,
discovery: RwLock::new(None),
})
}
fn validate_issuer_transport(issuer_url: &str) -> Result<()> {
let issuer = url::Url::parse(issuer_url).map_err(|e| Error::InvalidInput {
message: format!("Invalid OAuth issuer_url: {e}"),
})?;
match issuer.scheme() {
"https" => Ok(()),
"http" if Self::is_loopback_issuer(&issuer) => Ok(()),
_ => Err(Error::InvalidInput {
message:
"ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
.to_string(),
}),
}
}
fn is_loopback_issuer(issuer: &url::Url) -> bool {
let Some(host) = issuer.host_str() else {
return false;
};
host.eq_ignore_ascii_case("localhost")
|| host
.parse::<IpAddr>()
.map(|addr| addr.is_loopback())
.unwrap_or(false)
}
async fn get_discovery(&self) -> Result<OidcDiscovery> {
{
let cached = self.discovery.read().await;
if let Some(ref disc) = *cached {
return Ok(disc.clone());
}
}
let mut cache = self.discovery.write().await;
// Double-check
if let Some(ref disc) = *cache {
return Ok(disc.clone());
}
let discovery_url = format!(
"{}/.well-known/openid-configuration",
self.issuer_url.trim_end_matches('/')
);
debug!("Fetching OIDC discovery from {}", discovery_url);
let resp = self
.http_client
.get(&discovery_url)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to fetch OIDC discovery document: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"OIDC discovery failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse OIDC discovery document: {e}"),
})?;
let result = disc.clone();
*cache = Some(disc);
Ok(result)
}
async fn get_token_endpoint(&self) -> Result<String> {
self.get_discovery().await.map(|disc| disc.token_endpoint)
}
fn scopes_string(&self) -> String {
self.scopes.join(" ")
}
async fn post_token_request(
&self,
endpoint: &str,
params: &[(&str, &str)],
) -> Result<TokenResponse> {
let resp = self
.http_client
.post(endpoint)
.form(params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Token request to {endpoint} failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Token request failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse token response: {e}"),
})
}
}
#[async_trait]
impl TokenSource for ClientCredentialsSource {
async fn fetch_token(&self) -> Result<TokenResponse> {
let token_endpoint = self.get_token_endpoint().await?;
let scope = self.scopes_string();
let params = [
("grant_type", "client_credentials"),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("scope", scope.as_str()),
];
self.post_token_request(&token_endpoint, &params).await
}
}
struct AzureImdsSource {
client_id: Option<String>,
resource: String,
http_client: Client,
}
impl std::fmt::Debug for AzureImdsSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AzureImdsSource")
.field("client_id", &self.client_id)
.field("resource", &self.resource)
.finish()
}
}
impl AzureImdsSource {
fn new(scopes: Vec<String>, client_id: Option<String>) -> Result<Self> {
let resource = Self::resource_from_scopes(&scopes)?;
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.no_proxy()
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for Azure IMDS OAuth: {e}"),
})?;
Ok(Self {
client_id,
resource,
http_client,
})
}
fn resource_from_scopes(scopes: &[String]) -> Result<String> {
let [scope] = scopes else {
return Err(Error::InvalidInput {
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
.to_string(),
});
};
Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string())
}
}
#[async_trait]
impl TokenSource for AzureImdsSource {
async fn fetch_token(&self) -> Result<TokenResponse> {
let mut url = format!(
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
urlencoding::encode(&self.resource),
);
if let Some(cid) = self.client_id.as_deref() {
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
}
let resp = self
.http_client
.get(&url)
.header("Metadata", "true")
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Azure IMDS request failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Azure IMDS returned status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse IMDS token response: {e}"),
})
}
}
/// OAuth header provider that manages the full token lifecycle.
///
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
/// headers into every LanceDB request, with automatic token refresh.
pub struct OAuthHeaderProvider {
token_source: Box<dyn TokenSource>,
token_state: Arc<RwLock<TokenState>>,
refresh_buffer: Duration,
}
impl std::fmt::Debug for OAuthHeaderProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OAuthHeaderProvider")
.field("token_source", &self.token_source)
.finish()
}
}
impl OAuthHeaderProvider {
/// Create a new OAuth header provider from configuration.
pub fn new(config: OAuthConfig) -> Result<Self> {
let OAuthConfig {
issuer_url,
client_id,
client_secret,
scopes,
flow,
refresh_buffer_secs,
} = config;
if scopes.is_empty() {
return Err(Error::InvalidInput {
message: "At least one OAuth scope is required".to_string(),
});
}
let refresh_buffer =
Duration::from_secs(refresh_buffer_secs.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS));
let token_source: Box<dyn TokenSource> = match flow {
OAuthFlow::ClientCredentials => Box::new(ClientCredentialsSource::new(
issuer_url,
client_id,
client_secret,
scopes,
)?),
OAuthFlow::AzureManagedIdentity { client_id } => {
Box::new(AzureImdsSource::new(scopes, client_id)?)
}
};
Ok(Self {
token_source,
token_state: Arc::new(RwLock::new(TokenState::new())),
refresh_buffer,
})
}
/// Get a valid access token, refreshing if necessary.
async fn get_valid_token(&self) -> Result<String> {
// Fast path: check if current token is still valid
{
let state = self.token_state.read().await;
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
}
// Slow path: acquire or refresh token
let mut state = self.token_state.write().await;
// Double-check after acquiring write lock
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
debug!("Acquiring new OAuth token via {:?}", self.token_source);
let resp = self.token_source.fetch_token().await?;
state.update(&resp);
Ok(resp.access_token)
}
}
#[async_trait]
impl HeaderProvider for OAuthHeaderProvider {
async fn get_headers(&self) -> Result<HashMap<String, String>> {
let token = self.get_valid_token().await?;
Ok(HashMap::from([(
"authorization".to_string(),
format!("Bearer {token}"),
)]))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task::JoinHandle;
#[test]
fn test_token_state_expiry() {
let mut state = TokenState::new();
assert!(state.is_expired(Duration::from_secs(0)));
state.access_token = Some("tok".to_string());
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
assert!(!state.is_expired(Duration::from_secs(300)));
assert!(state.is_expired(Duration::from_secs(601)));
state.expires_at = None;
assert!(state.is_expired(Duration::from_secs(0)));
}
#[test]
fn test_token_state_uses_default_expiry() {
let mut state = TokenState::new();
let response = TokenResponse {
access_token: "tok".to_string(),
expires_in: None,
token_type: None,
};
state.update(&response);
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1)));
assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1)));
}
#[test]
fn test_token_response_accepts_float_expires_in() {
let response: TokenResponse =
serde_json::from_str(r#"{"access_token":"tok","expires_in":3600.0}"#).unwrap();
assert_eq!(response.expires_in, Some(3600));
}
#[test]
fn test_token_response_rejects_negative_expires_in() {
let err =
serde_json::from_str::<TokenResponse>(r#"{"access_token":"tok","expires_in":-1}"#)
.unwrap_err();
assert!(err.to_string().contains("invalid expires_in value: -1"));
}
#[test]
fn test_token_response_debug_redacts_access_token() {
let response = TokenResponse {
access_token: "secret-token".to_string(),
expires_in: Some(3600),
token_type: Some("Bearer".to_string()),
};
let debug = format!("{response:?}");
assert!(!debug.contains("secret-token"));
assert!(debug.contains("access_token: \"<redacted>\""));
}
#[test]
fn test_scopes_string() {
let source = ClientCredentialsSource::new(
"https://login.microsoftonline.com/tenant/v2.0".to_string(),
"app-id".to_string(),
Some("secret".to_string()),
vec!["scope1".to_string(), "scope2".to_string()],
)
.unwrap();
assert_eq!(source.scopes_string(), "scope1 scope2");
}
#[test]
fn test_oauth_config_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("super-secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let debug = format!("{config:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
}
#[test]
fn test_oauth_header_provider_debug_redacts_client_secret() {
let config = OAuthConfig {
issuer_url: "https://issuer.example.com".to_string(),
client_id: "client-id".to_string(),
client_secret: Some("super-secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let provider = OAuthHeaderProvider::new(config).unwrap();
let debug = format!("{provider:?}");
assert!(!debug.contains("super-secret"));
assert!(debug.contains("client_secret: \"<redacted>\""));
}
#[test]
fn test_managed_identity_resource_from_default_scope() {
assert_eq!(
AzureImdsSource::resource_from_scopes(&["api://test/.default".to_string()]).unwrap(),
"api://test"
);
}
#[test]
fn test_managed_identity_resource_without_default_suffix() {
assert_eq!(
AzureImdsSource::resource_from_scopes(&["api://test".to_string()]).unwrap(),
"api://test"
);
}
#[test]
fn test_managed_identity_rejects_multiple_scopes() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec![
"api://test-a/.default".to_string(),
"api://test-b/.default".to_string(),
],
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[tokio::test]
async fn test_token_endpoint_requires_discovery_success() {
let (issuer_url, server) = spawn_discovery_error_server().await;
let source = ClientCredentialsSource::new(
issuer_url,
"client-id".to_string(),
Some("secret".to_string()),
vec!["scope".to_string()],
)
.unwrap();
let err = source.get_token_endpoint().await.unwrap_err();
assert!(matches!(
err,
Error::Runtime { message }
if message.contains("OIDC discovery failed with status 503")
));
server.await.unwrap();
}
#[test]
fn test_client_credentials_requires_secret() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[test]
fn test_client_credentials_rejects_insecure_non_loopback_issuer() {
let config = OAuthConfig {
issuer_url: "http://issuer.example.com".to_string(),
client_id: "app-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: None,
};
let err = OAuthHeaderProvider::new(config).unwrap_err();
assert!(matches!(
err,
Error::InvalidInput { message }
if message == "ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
));
}
#[test]
fn test_empty_scopes_rejected() {
let config = OAuthConfig {
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
client_id: "app-id".to_string(),
client_secret: None,
scopes: vec![],
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
refresh_buffer_secs: None,
};
assert!(OAuthHeaderProvider::new(config).is_err());
}
#[tokio::test]
async fn test_client_credentials_token_lifecycle() {
let (issuer_url, token_requests, server) = spawn_oauth_server().await;
let config = OAuthConfig {
issuer_url,
client_id: "client-id".to_string(),
client_secret: Some("secret".to_string()),
scopes: vec!["scope".to_string()],
flow: OAuthFlow::ClientCredentials,
refresh_buffer_secs: Some(0),
};
let provider = OAuthHeaderProvider::new(config).unwrap();
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
provider.token_state.write().await.expires_at =
Some(Instant::now() - Duration::from_secs(1));
let headers = provider.get_headers().await.unwrap();
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2");
assert_eq!(token_requests.load(Ordering::SeqCst), 2);
server.await.unwrap();
}
async fn spawn_oauth_server() -> (String, Arc<AtomicUsize>, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let issuer_url = format!("http://{addr}");
let token_requests = Arc::new(AtomicUsize::new(0));
let server_token_requests = Arc::clone(&token_requests);
let server = tokio::spawn(async move {
for _ in 0..3 {
let (mut stream, _) = listener.accept().await.unwrap();
let (request_line, body) = read_http_request(&mut stream).await;
if request_line.starts_with("GET /.well-known/openid-configuration ") {
let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#);
write_json_response(&mut stream, "200 OK", &discovery).await;
} else if request_line.starts_with("POST /token ") {
assert!(body.contains("grant_type=client_credentials"));
assert!(body.contains("client_id=client-id"));
assert!(body.contains("client_secret=secret"));
assert!(body.contains("scope=scope"));
let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1;
let token = format!(
r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"#
);
write_json_response(&mut stream, "200 OK", &token).await;
} else {
write_json_response(&mut stream, "404 Not Found", "{}").await;
}
}
});
(issuer_url, token_requests, server)
}
async fn spawn_discovery_error_server() -> (String, JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let issuer_url = format!("http://{addr}");
let server = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let (request_line, _) = read_http_request(&mut stream).await;
assert!(request_line.starts_with("GET /.well-known/openid-configuration "));
write_json_response(&mut stream, "503 Service Unavailable", "{}").await;
});
(issuer_url, server)
}
async fn read_http_request(stream: &mut TcpStream) -> (String, String) {
let mut buffer = Vec::new();
let mut header_end = None;
while header_end.is_none() {
let mut chunk = [0; 1024];
let read = stream.read(&mut chunk).await.unwrap();
assert_ne!(read, 0, "connection closed before request headers");
buffer.extend_from_slice(&chunk[..read]);
header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4);
}
let header_end = header_end.unwrap();
let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
let request_line = headers.lines().next().unwrap_or_default().to_string();
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
while buffer.len() < header_end + content_length {
let mut chunk = [0; 1024];
let read = stream.read(&mut chunk).await.unwrap();
assert_ne!(read, 0, "connection closed before request body");
buffer.extend_from_slice(&chunk[..read]);
}
let body =
String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string();
(request_line, body)
}
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
haystack
.windows(needle.len())
.position(|window| window == needle)
}
async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) {
let response = format!(
"HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
);
stream.write_all(response.as_bytes()).await.unwrap();
}
}

View File

@@ -1352,35 +1352,6 @@ impl<S: HttpSend + 'static> RemoteTable<S> {
}
}
/// Deserialize an index's `created_at` field.
///
/// The server returns this as an RFC 3339 string (e.g. `"2026-06-18T21:37:36.637Z"`),
/// but older deployments sent a unix timestamp in milliseconds. Accept both so the
/// client works against any server version.
fn deserialize_created_at<'de, D>(
deserializer: D,
) -> std::result::Result<Option<DateTime<Utc>>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error as _;
#[derive(Deserialize)]
#[serde(untagged)]
enum CreatedAt {
Rfc3339(String),
Millis(i64),
}
match Option::<CreatedAt>::deserialize(deserializer)? {
None => Ok(None),
Some(CreatedAt::Rfc3339(s)) => DateTime::parse_from_rfc3339(&s)
.map(|dt| Some(dt.with_timezone(&Utc)))
.map_err(D::Error::custom),
Some(CreatedAt::Millis(ms)) => Ok(DateTime::from_timestamp_millis(ms)),
}
}
impl<S: HttpSend + 'static> RemoteTable<S> {
/// Parse the response from `/index/list/` into `IndexConfig` entries.
///
@@ -1409,7 +1380,7 @@ impl<S: HttpSend + 'static> RemoteTable<S> {
// Used as the sentinel to decide whether to skip the stats call.
index_type: Option<IndexType>,
index_uuid: Option<String>,
#[serde(default, deserialize_with = "deserialize_created_at")]
#[serde(default, with = "chrono::serde::ts_milliseconds_option")]
created_at: Option<DateTime<Utc>>,
num_indexed_rows: Option<u64>,
num_unindexed_rows: Option<u64>,
@@ -2309,126 +2280,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
message: "optimize is not supported on LanceDB cloud.".into(),
})
}
async fn add_computed_columns(
&self,
columns: &[(String, String)],
expression: &str,
) -> Result<()> {
let new_columns: Vec<serde_json::Value> = columns
.iter()
.map(|(name, data_type)| {
serde_json::json!({
"name": name,
"computed": { "data_type": data_type, "expression": expression },
})
})
.collect();
let request = self
.client
.post(&format!("/v1/table/{}/add_columns/", self.identifier))
.json(&serde_json::json!({ "new_columns": new_columns }));
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
async fn refresh_column(
&self,
columns: &[String],
where_clause: Option<String>,
num_workers: Option<u32>,
max_workers: Option<u32>,
batch_size: Option<u32>,
priority: Option<String>,
) -> Result<String> {
let mut body = serde_json::json!({ "columns": columns });
if let Some(w) = where_clause {
body["where_clause"] = serde_json::Value::String(w);
}
if let Some(n) = num_workers {
body["num_workers"] = n.into();
}
if let Some(n) = max_workers {
body["max_workers"] = n.into();
}
if let Some(n) = batch_size {
body["batch_size"] = n.into();
}
if let Some(p) = priority {
body["priority"] = serde_json::Value::String(p);
}
let request = self
.client
.post(&format!("/v1/table/{}/refresh_column", self.identifier))
.json(&body);
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
#[derive(serde::Deserialize)]
struct RefreshColumnResponse {
job_id: String,
}
let body: RefreshColumnResponse = response.json().await.err_to_http(request_id)?;
Ok(body.job_id)
}
async fn load_columns(&self, request: crate::table::LoadColumnsRequest) -> Result<String> {
let columns: Vec<serde_json::Value> = request
.columns
.iter()
.map(|(target, source)| {
serde_json::json!({
"target": target,
"source": source.clone().unwrap_or_else(|| target.clone()),
})
})
.collect();
let mut source = serde_json::json!({
"uris": request.source_uris,
"format": request.source_format,
});
if let Some(opts) = request.source_storage_options {
source["storage_options"] = serde_json::to_value(opts).unwrap_or_default();
}
let mut body = serde_json::json!({
"columns": columns,
"source": source,
"target_key": request.target_key,
});
if let Some(k) = request.source_key {
body["source_key"] = serde_json::Value::String(k);
}
if let Some(m) = request.on_missing {
body["on_missing"] = serde_json::Value::String(m);
}
if let Some(n) = request.num_workers {
body["num_workers"] = n.into();
}
if let Some(n) = request.max_workers {
body["max_workers"] = n.into();
}
if let Some(n) = request.batch_size {
body["batch_size"] = n.into();
}
if let Some(n) = request.commit_granularity {
body["commit_granularity"] = n.into();
}
if let Some(p) = request.priority {
body["priority"] = serde_json::Value::String(p);
}
let http_request = self
.client
.post(&format!("/v1/table/{}/load_columns", self.identifier))
.json(&body);
let (request_id, response) = self.send(http_request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
#[derive(serde::Deserialize)]
struct LoadColumnsResponse {
job_id: String,
}
let body: LoadColumnsResponse = response.json().await.err_to_http(request_id)?;
Ok(body.job_id)
}
async fn add_columns(
&self,
transforms: NewColumnTransform,
@@ -2921,75 +2772,6 @@ mod tests {
}
}
#[tokio::test]
async fn test_refresh_column() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/refresh_column");
let body: serde_json::Value =
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
assert_eq!(body["columns"], serde_json::json!(["vec"]));
assert_eq!(body["num_workers"], 2);
assert!(body.get("where_clause").is_none());
http::Response::builder()
.status(202)
.body(r#"{"job_id":"j-9"}"#)
.unwrap()
});
let job_id = table
.refresh_column(&["vec".to_string()], None, Some(2), None, None, None)
.await
.unwrap();
assert_eq!(job_id, "j-9");
}
#[tokio::test]
async fn test_load_columns() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/load_columns");
let body: serde_json::Value =
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
assert_eq!(
body["columns"],
serde_json::json!([{"target": "embedding", "source": "emb"}])
);
assert_eq!(body["source"]["format"], "parquet");
assert_eq!(
body["source"]["uris"],
serde_json::json!(["s3://b/x.parquet"])
);
assert_eq!(body["target_key"], "document_id");
assert_eq!(body["source_key"], "doc_id");
assert_eq!(body["on_missing"], "null");
assert_eq!(body["num_workers"], 4);
http::Response::builder()
.status(202)
.body(r#"{"job_id":"lc-7"}"#)
.unwrap()
});
let request = crate::table::LoadColumnsRequest {
source_uris: vec!["s3://b/x.parquet".to_string()],
source_format: "parquet".to_string(),
source_storage_options: None,
target_key: "document_id".to_string(),
source_key: Some("doc_id".to_string()),
columns: vec![("embedding".to_string(), Some("emb".to_string()))],
on_missing: Some("null".to_string()),
num_workers: Some(4),
max_workers: None,
batch_size: None,
commit_granularity: None,
priority: None,
};
let job_id = table.load_columns(request).await.unwrap();
assert_eq!(job_id, "lc-7");
}
#[tokio::test]
async fn test_version() {
let table = Table::new_with_handler("my_table", |request| {
@@ -4896,7 +4678,7 @@ mod tests {
"num_segments": 2,
"index_version": 1,
"index_details": "{\"num_partitions\":16}",
"created_at": "2026-06-18T21:37:36.637Z",
"created_at": 1700000000000i64,
"type_url": "type.googleapis.com/lance.index.vector.IvfPq",
},
{
@@ -4946,10 +4728,7 @@ mod tests {
vec_idx.type_url,
Some("type.googleapis.com/lance.index.vector.IvfPq".to_string())
);
assert_eq!(
vec_idx.created_at,
Some("2026-06-18T21:37:36.637Z".parse::<DateTime<Utc>>().unwrap())
);
assert!(vec_idx.created_at.is_some());
let text_idx = &indices[1];
assert_eq!(text_idx.name, "text_idx");
@@ -4970,36 +4749,6 @@ mod tests {
assert_eq!(text_idx.created_at, None);
}
#[test]
fn test_deserialize_created_at() {
#[derive(Deserialize)]
struct Wrapper {
#[serde(default, deserialize_with = "deserialize_created_at")]
created_at: Option<DateTime<Utc>>,
}
// RFC 3339 string (current server format).
let w: Wrapper =
serde_json::from_str(r#"{"created_at": "2026-06-18T21:37:36.637Z"}"#).unwrap();
assert_eq!(
w.created_at,
Some("2026-06-18T21:37:36.637Z".parse::<DateTime<Utc>>().unwrap())
);
// Unix milliseconds (legacy server format).
let w: Wrapper = serde_json::from_str(r#"{"created_at": 1700000000000}"#).unwrap();
assert_eq!(w.created_at, DateTime::from_timestamp_millis(1700000000000));
// Null and missing both yield None.
let w: Wrapper = serde_json::from_str(r#"{"created_at": null}"#).unwrap();
assert_eq!(w.created_at, None);
let w: Wrapper = serde_json::from_str(r#"{}"#).unwrap();
assert_eq!(w.created_at, None);
// A malformed string is rejected rather than silently dropped to None.
assert!(serde_json::from_str::<Wrapper>(r#"{"created_at": "not-a-date"}"#).is_err());
}
#[tokio::test]
async fn test_list_versions() {
let table = Table::new_with_handler("my_table", |request| {

View File

@@ -3,7 +3,7 @@
//! LanceDB Table APIs
use arrow_array::{LargeBinaryArray, RecordBatch, RecordBatchReader};
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_execution::TaskContext;
@@ -12,7 +12,6 @@ use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use lance::dataset::BlobFile;
pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
@@ -44,7 +43,6 @@ use crate::connection::NamespaceClientPushdownOperation;
use crate::data::scannable::{PeekedScannable, Scannable, estimate_write_partitions};
use crate::database::Database;
use crate::database::read_freshness::TableFreshness;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::IndexStatistics;
@@ -471,33 +469,6 @@ impl LsmWriteSpec {
}
}
/// Request to fill existing table columns from an external source by
/// primary-key join (Geneva `Table.load_columns()` parity). Server-backed
/// feature (LanceDB Enterprise / Cloud).
#[derive(Debug, Clone)]
pub struct LoadColumnsRequest {
/// External source URIs.
pub source_uris: Vec<String>,
/// Source format: "parquet" | "lance" | "ipc".
pub source_format: String,
/// Source-only storage options (e.g. cloud credentials).
pub source_storage_options: Option<HashMap<String, String>>,
/// Destination primary-key column.
pub target_key: String,
/// Source primary-key column. Defaults to `target_key` when None.
pub source_key: Option<String>,
/// Value column mappings as `(target, source)`; a None source defaults to
/// the target name.
pub columns: Vec<(String, Option<String>)>,
/// Missing-row policy: "carry" (default) | "null" | "error".
pub on_missing: Option<String>,
pub num_workers: Option<u32>,
pub max_workers: Option<u32>,
pub batch_size: Option<u32>,
pub commit_granularity: Option<u32>,
pub priority: Option<String>,
}
/// A trait for anything "table-like". This is used for both native tables (which target
/// Lance datasets) and remote tables (which target LanceDB cloud)
///
@@ -615,28 +586,6 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn close_lsm_writers(&self) -> Result<()> {
Ok(())
}
/// Names of the blob v2 columns in this table, in declaration order.
async fn blob_columns(&self) -> Result<Vec<String>> {
Err(Error::NotSupported {
message: "blob_columns is not supported on this table type".into(),
})
}
/// Materialize blob bytes for the given row ids. See [`Table::fetch_blobs`].
async fn fetch_blobs(&self, _column: &str, _row_ids: &[u64]) -> Result<LargeBinaryArray> {
Err(Error::NotSupported {
message: "fetch_blobs is not supported on this table type".into(),
})
}
/// Open lazy blob handles for the given row ids. See [`Table::fetch_blob_files`].
async fn fetch_blob_files(
&self,
_column: &str,
_row_ids: &[u64],
) -> Result<Vec<Option<BlobFile>>> {
Err(Error::NotSupported {
message: "fetch_blob_files is not supported on this table type".into(),
})
}
/// Gets the table tag manager.
async fn tags(&self) -> Result<Box<dyn Tags + '_>>;
/// Optimize the dataset.
@@ -647,47 +596,6 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
) -> Result<AddColumnsResult>;
/// Declare computed columns bound to a registered function: each
/// `(name, sql_type)` is added all-null with the expression stored
/// as its binding; no compute happens here (the server's lazy
/// detector or refresh_column fills them). Several columns map a
/// struct-returning function's fields positionally. Server-backed
/// feature; the default returns NotSupported.
async fn add_computed_columns(
&self,
_columns: &[(String, String)],
_expression: &str,
) -> Result<()> {
Err(Error::NotSupported {
message: "computed columns are not supported by this table".into(),
})
}
/// Trigger recompute of computed columns. The expression is
/// resolved server-side from each column's stored binding; columns
/// bound to the same struct-returning function refresh together.
/// Returns the refresh job id. Server-backed feature (LanceDB
/// Enterprise / Cloud); the default returns NotSupported.
async fn refresh_column(
&self,
_columns: &[String],
_where_clause: Option<String>,
_num_workers: Option<u32>,
_max_workers: Option<u32>,
_batch_size: Option<u32>,
_priority: Option<String>,
) -> Result<String> {
Err(Error::NotSupported {
message: "refresh_column is not supported by this table".into(),
})
}
/// Fill existing columns from an external source by primary-key join
/// (Geneva `load_columns`). Returns the load job id. Server-backed feature;
/// the default returns NotSupported.
async fn load_columns(&self, _request: LoadColumnsRequest) -> Result<String> {
Err(Error::NotSupported {
message: "load_columns is not supported by this table".into(),
})
}
/// Alter columns in the table.
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult>;
/// Drop columns from the table.
@@ -1018,76 +926,6 @@ impl Table {
self.inner.count_rows(filter.map(Filter::Sql)).await
}
/// Names of the blob v2 columns in this table, in declaration order.
///
/// Nested blobs use dotted paths (e.g. `info.blob`). Returns
/// [`Error::NotSupported`] on table types without blob support.
pub async fn blob_columns(&self) -> Result<Vec<String>> {
self.inner.blob_columns().await
}
/// Materialize blob bytes for the given row ids.
///
/// Output matches `row_ids` in length and order. Null and zero-length rows
/// are null. Prefer [`Self::fetch_blob_files`] for large selections.
///
/// ```
/// use arrow_array::UInt64Array;
/// use futures::TryStreamExt;
/// use lancedb::query::{ExecutableQuery, QueryBase};
///
/// # use lancedb::Table;
/// # async fn materialize(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
/// let mut stream = table.query().with_row_id().limit(10).execute().await?;
/// while let Some(batch) = stream.try_next().await? {
/// let row_ids = batch
/// .column_by_name("_rowid")
/// .unwrap()
/// .as_any()
/// .downcast_ref::<UInt64Array>()
/// .unwrap();
/// let images = table.fetch_blobs("image", row_ids.values()).await?;
/// let _ = images;
/// }
/// # Ok(())
/// # }
/// ```
///
/// Returns [`Error::InvalidInput`] when the column does not exist or is
/// not a blob v2 column, and [`Error::NotSupported`] on table types
/// without blob support.
pub async fn fetch_blobs(
&self,
column: impl AsRef<str>,
row_ids: &[u64],
) -> Result<LargeBinaryArray> {
self.inner.fetch_blobs(column.as_ref(), row_ids).await
}
/// Open lazy [`BlobFile`] handles for the given row ids.
///
/// Same length and order as `row_ids`. Null rows are `None`. Bytes are not
/// read from disk until a call to [`BlobFile::read`].
///
/// ```
/// # use lancedb::Table;
/// # async fn lazy_read(table: &Table, row_ids: &[u64]) -> Result<(), Box<dyn std::error::Error>> {
/// let handles = table.fetch_blob_files("image", row_ids).await?;
/// if let Some(Some(first)) = handles.first() {
/// let bytes = first.read().await?;
/// println!("first blob is {} bytes", bytes.len());
/// }
/// # Ok(())
/// # }
/// ```
pub async fn fetch_blob_files(
&self,
column: impl AsRef<str>,
row_ids: &[u64],
) -> Result<Vec<Option<BlobFile>>> {
self.inner.fetch_blob_files(column.as_ref(), row_ids).await
}
/// Insert new records into this Table
///
/// # Arguments
@@ -1529,48 +1367,6 @@ impl Table {
self.inner.add_columns(transforms, read_columns).await
}
/// Declare computed columns bound to a registered function
/// (`(name, sql_type)` pairs + a `f(args)` expression). No compute
/// happens here. Server-backed feature.
pub async fn add_computed_columns(
&self,
columns: &[(String, String)],
expression: &str,
) -> Result<()> {
self.inner.add_computed_columns(columns, expression).await
}
/// Trigger recompute of computed columns (REFRESH COLUMN). The
/// expression comes from each column's stored binding; columns
/// bound to the same struct-returning function refresh together.
/// Returns the refresh job id. Server-backed feature.
pub async fn refresh_column(
&self,
columns: &[String],
where_clause: Option<String>,
num_workers: Option<u32>,
max_workers: Option<u32>,
batch_size: Option<u32>,
priority: Option<String>,
) -> Result<String> {
self.inner
.refresh_column(
columns,
where_clause,
num_workers,
max_workers,
batch_size,
priority,
)
.await
}
/// Fill existing columns from an external Parquet/Lance/IPC source by
/// primary-key join (Geneva `Table.load_columns()`). Returns the job id.
pub async fn load_columns(&self, request: LoadColumnsRequest) -> Result<String> {
self.inner.load_columns(request).await
}
/// Change a column's name or nullability.
pub async fn alter_columns(
&self,
@@ -1967,8 +1763,6 @@ pub struct NativeTable {
// Operations to push down to the namespace server.
// pub(crate) so query.rs can access the field for server-side query execution.
pub(crate) pushdown_operations: HashSet<NamespaceClientPushdownOperation>,
// Read-freshness baseline; `Some` only for namespace-backed tables.
freshness: Option<TableFreshness>,
}
impl std::fmt::Debug for NativeTable {
@@ -2129,7 +1923,6 @@ impl NativeTable {
read_consistency_interval,
namespace_client,
pushdown_operations,
freshness: None,
})
}
@@ -2141,12 +1934,6 @@ impl NativeTable {
self
}
/// Attach the read-freshness baseline handle (namespace connections only).
pub(crate) fn with_freshness(mut self, freshness: TableFreshness) -> Self {
self.freshness = Some(freshness);
self
}
/// Build a sibling `NativeTable` with the same identity but a different
/// (independent) dataset wrapper — used to hand out branch-scoped handles.
fn with_dataset(&self, dataset: dataset::DatasetConsistencyWrapper) -> Self {
@@ -2159,14 +1946,6 @@ impl NativeTable {
read_consistency_interval: self.read_consistency_interval,
namespace_client: self.namespace_client.clone(),
pushdown_operations: self.pushdown_operations.clone(),
freshness: self.freshness.clone(),
}
}
/// Bump the read-freshness baseline; no-op for non-namespace tables.
fn bump_freshness(&self) {
if let Some(freshness) = &self.freshness {
freshness.bump();
}
}
@@ -2266,7 +2045,6 @@ impl NativeTable {
read_consistency_interval,
namespace_client: stored_namespace_client,
pushdown_operations,
freshness: None,
})
}
@@ -2356,7 +2134,6 @@ impl NativeTable {
read_consistency_interval,
namespace_client,
pushdown_operations,
freshness: None,
})
}
@@ -2488,7 +2265,6 @@ impl NativeTable {
read_consistency_interval,
namespace_client: stored_namespace_client,
pushdown_operations,
freshness: None,
})
}
@@ -2648,8 +2424,6 @@ impl BaseTable for NativeTable {
}
async fn checkout_latest(&self) -> Result<()> {
// Bump before resolving "latest" so that request carries the floor.
self.bump_freshness();
self.dataset.as_latest().await?;
self.dataset.reload().await
}
@@ -2737,8 +2511,6 @@ impl BaseTable for NativeTable {
debug_assert_eq!(dataset.version().version, version);
dataset.restore().await?;
}
// Restore moves "latest", so bump before resolving it (as RemoteTable does).
self.bump_freshness();
self.dataset.as_latest().await?;
Ok(())
}
@@ -2819,13 +2591,7 @@ impl BaseTable for NativeTable {
output.plan
};
let insert_exec = Arc::new(InsertExec::new_with_tracker(
ds_wrapper.clone(),
ds,
plan,
lance_params,
output.tracker.clone(),
));
let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params));
let tracker_for_tasks = output.tracker.clone();
if let Some(ref t) = tracker_for_tasks {
@@ -2858,7 +2624,6 @@ impl BaseTable for NativeTable {
}
let version = ds_wrapper.get().await?.manifest().version;
self.bump_freshness();
Ok(AddResult { version })
}
@@ -2909,9 +2674,7 @@ impl BaseTable for NativeTable {
async fn update(&self, update: UpdateBuilder) -> Result<UpdateResult> {
// Delegate to the submodule implementation
let result = update::execute_update(self, update).await?;
self.bump_freshness();
Ok(result)
update::execute_update(self, update).await
}
async fn create_plan(
@@ -2943,9 +2706,7 @@ impl BaseTable for NativeTable {
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<MergeResult> {
let result = merge::execute_merge_insert(self, params, new_data).await?;
self.bump_freshness();
Ok(result)
merge::execute_merge_insert(self, params, new_data).await
}
async fn set_unenforced_primary_key(&self, columns: &[&str]) -> Result<()> {
@@ -2964,30 +2725,9 @@ impl BaseTable for NativeTable {
merge::lsm::close_lsm_writers(self).await
}
async fn blob_columns(&self) -> Result<Vec<String>> {
let schema = self.schema().await?;
Ok(crate::blob::blob_column_names(schema.as_ref()))
}
async fn fetch_blobs(&self, column: &str, row_ids: &[u64]) -> Result<LargeBinaryArray> {
let dataset = self.dataset.get().await?;
crate::blob::take_blobs_aligned(&dataset, column, row_ids).await
}
async fn fetch_blob_files(
&self,
column: &str,
row_ids: &[u64],
) -> Result<Vec<Option<BlobFile>>> {
let dataset = self.dataset.get().await?;
crate::blob::take_blob_files_aligned(&dataset, column, row_ids).await
}
/// Delete rows from the table
async fn delete(&self, predicate: Predicate<'_>) -> Result<DeleteResult> {
let result = delete::execute_delete(self, predicate).await?;
self.bump_freshness();
Ok(result)
delete::execute_delete(self, predicate).await
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
@@ -3006,30 +2746,22 @@ impl BaseTable for NativeTable {
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
) -> Result<AddColumnsResult> {
let result = schema_evolution::execute_add_columns(self, transforms, read_columns).await?;
self.bump_freshness();
Ok(result)
schema_evolution::execute_add_columns(self, transforms, read_columns).await
}
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult> {
let result = schema_evolution::execute_alter_columns(self, alterations).await?;
self.bump_freshness();
Ok(result)
schema_evolution::execute_alter_columns(self, alterations).await
}
async fn update_field_metadata(
&self,
updates: &[FieldMetadataUpdate],
) -> Result<UpdateFieldMetadataResult> {
let result = schema_evolution::execute_update_field_metadata(self, updates).await?;
self.bump_freshness();
Ok(result)
schema_evolution::execute_update_field_metadata(self, updates).await
}
async fn drop_columns(&self, columns: &[&str]) -> Result<DropColumnsResult> {
let result = schema_evolution::execute_drop_columns(self, columns).await?;
self.bump_freshness();
Ok(result)
schema_evolution::execute_drop_columns(self, columns).await
}
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {

View File

@@ -26,9 +26,6 @@ pub enum AddDataMode {
#[default]
Append,
/// The existing table will be overwritten with the new data
///
/// On overwrite, raw binary is not coerced into a blob struct. The input
/// must declare blob v2 for the column to stay a blob column.
Overwrite,
}

View File

@@ -3,7 +3,6 @@
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
mod blob_coerce;
pub mod cast;
pub mod insert;
pub mod reject_nan;

View File

@@ -1,495 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Coerces write-path input into blob v2 struct columns.
//!
//! [`super::cast::cast_to_table_schema`] calls [`coerce_blob_expr`].
use std::sync::Arc;
use arrow_schema::{DataType, Field, FieldRef};
use datafusion::functions::core::{get_field, named_struct};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr::expressions::{CastExpr, Literal};
use datafusion_physical_plan::PhysicalExpr;
use crate::error::{Error, Result};
/// Build a projection expression coercing `input_expr` into the blob struct
/// declared by `table_field`, composing `named_struct` / `get_field` / `cast`.
pub(super) fn coerce_blob_expr(
input_expr: Arc<dyn PhysicalExpr>,
input_field: &Field,
table_field: &FieldRef,
config: &Arc<ConfigOptions>,
) -> Result<(Arc<dyn PhysicalExpr>, FieldRef)> {
let DataType::Struct(declared_fields) = table_field.data_type() else {
return Err(Error::InvalidInput {
message: format!(
"blob v2 column '{}' must be a struct, table declares {}",
table_field.name(),
table_field.data_type()
),
});
};
let input_struct_children = match input_field.data_type() {
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => None,
DataType::Struct(children) => {
if !children
.iter()
.any(|c| c.name() == "data" || c.name() == "uri")
{
return Err(Error::InvalidInput {
message: format!(
"blob struct input for column '{}' must contain a 'data' or 'uri' child",
table_field.name()
),
});
}
Some(children)
}
other => {
return Err(Error::InvalidInput {
message: format!(
"cannot coerce column '{}' with type {} into a blob v2 struct. \
expected Binary, LargeBinary, BinaryView, or a Struct with a 'data' or 'uri' child",
table_field.name(),
other,
),
});
}
};
let mut ns_args: Vec<Arc<dyn PhysicalExpr>> = Vec::with_capacity(declared_fields.len() * 2);
for declared in declared_fields.iter() {
ns_args.push(Arc::new(Literal::new(ScalarValue::from(
declared.name().as_str(),
))));
let value: Arc<dyn PhysicalExpr> = match input_struct_children {
// Raw binary lands in `data` and everything else is a typed null.
None => {
if declared.name() == "data" {
Arc::new(CastExpr::new(
input_expr.clone(),
declared.data_type().clone(),
None,
))
} else {
typed_null(declared.data_type())?
}
}
Some(children) => match children.iter().find(|c| c.name() == declared.name()) {
Some(child) => {
let field_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
&format!("get_field({})", declared.name()),
get_field(),
vec![
input_expr.clone(),
Arc::new(Literal::new(ScalarValue::from(declared.name().as_str()))),
],
Arc::new(child.as_ref().clone()),
config.clone(),
));
if child.data_type() == declared.data_type() {
field_expr
} else {
Arc::new(CastExpr::new(
field_expr,
declared.data_type().clone(),
None,
))
}
}
None => typed_null(declared.data_type())?,
},
};
ns_args.push(value);
}
let expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
&format!("named_struct({})", table_field.name()),
named_struct(),
ns_args,
table_field.clone(),
config.clone(),
));
Ok((expr, table_field.clone()))
}
fn typed_null(data_type: &DataType) -> Result<Arc<dyn PhysicalExpr>> {
let scalar = ScalarValue::try_from(data_type).map_err(|e| Error::InvalidInput {
message: format!("cannot build null literal for blob child type {data_type}: {e}"),
})?;
Ok(Arc::new(Literal::new(scalar)))
}
#[cfg(test)]
mod tests {
use super::super::cast::cast_to_table_schema;
use super::*;
use crate::blob::blob;
use arrow_array::{
Array, ArrayRef, BinaryArray, BinaryViewArray, Int32Array, Int64Array, LargeBinaryArray,
RecordBatch, StringArray, StructArray, UInt8Array, UInt64Array,
};
use arrow_schema::Schema;
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use datafusion_physical_plan::ExecutionPlan;
use futures::TryStreamExt;
use lance_arrow::FieldExt;
use std::collections::HashMap;
fn wide_blob_field(name: &str) -> Field {
Field::new(
name,
DataType::Struct(
vec![
Field::new("data", DataType::LargeBinary, true),
Field::new("uri", DataType::Utf8, true),
Field::new("position", DataType::UInt64, true),
Field::new("size", DataType::UInt64, true),
]
.into(),
),
true,
)
.with_metadata(HashMap::from([(
"ARROW:extension:name".to_string(),
"lance.blob.v2".to_string(),
)]))
}
fn blob_table_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int64, false),
blob("image", true),
])
}
fn batch_with_image(image_field: Field, image: ArrayRef) -> RecordBatch {
let len = image.len();
RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
image_field,
])),
vec![Arc::new(Int64Array::from_iter_values(0..len as i64)), image],
)
.unwrap()
}
fn image_struct(batch: &RecordBatch) -> &StructArray {
batch
.column_by_name("image")
.unwrap()
.as_any()
.downcast_ref::<StructArray>()
.unwrap()
}
async fn plan_from_batch(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
let schema = batch.schema();
let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
let ctx = SessionContext::new();
ctx.register_table("t", Arc::new(table)).unwrap();
let df = ctx.table("t").await.unwrap();
df.create_physical_plan().await.unwrap()
}
async fn coerce(batch: RecordBatch, table_schema: &Schema) -> RecordBatch {
let plan = plan_from_batch(batch).await;
let plan = cast_to_table_schema(plan, table_schema).unwrap();
let ctx = SessionContext::new();
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
arrow_select::concat::concat_batches(&plan.schema(), &batches).unwrap()
}
async fn coerce_err(batch: RecordBatch, table_schema: &Schema) -> Error {
let plan = plan_from_batch(batch).await;
cast_to_table_schema(plan, table_schema).unwrap_err()
}
#[tokio::test]
async fn large_binary_coerces_to_declared_blob_struct() {
let batch = batch_with_image(
Field::new("image", DataType::LargeBinary, true),
Arc::new(LargeBinaryArray::from_iter_values([b"hello".as_slice()])),
);
let coerced = coerce(batch, &blob_table_schema()).await;
let image_field = coerced.schema().field_with_name("image").unwrap().clone();
assert!(image_field.is_blob_v2());
assert!(matches!(image_field.data_type(), DataType::Struct(_)));
let data = image_struct(&coerced).column_by_name("data").unwrap();
let data: &LargeBinaryArray = data.as_any().downcast_ref().unwrap();
assert_eq!(data.value(0), b"hello");
}
#[tokio::test]
async fn binary_coerces_to_declared_blob_struct() {
let batch = batch_with_image(
Field::new("image", DataType::Binary, true),
Arc::new(BinaryArray::from_iter_values([b"hi".as_slice()])),
);
let coerced = coerce(batch, &blob_table_schema()).await;
assert!(
coerced
.schema()
.field_with_name("image")
.unwrap()
.is_blob_v2()
);
}
#[tokio::test]
async fn binary_view_coerces_to_declared_blob_struct() {
let batch = batch_with_image(
Field::new("image", DataType::BinaryView, true),
Arc::new(BinaryViewArray::from_iter_values([b"view".as_slice()])),
);
let coerced = coerce(batch, &blob_table_schema()).await;
let data = image_struct(&coerced).column_by_name("data").unwrap();
let data: &LargeBinaryArray = data.as_any().downcast_ref().unwrap();
assert_eq!(data.value(0), b"view");
}
#[tokio::test]
async fn binary_nulls_stay_null_after_coercion() {
let batch = batch_with_image(
Field::new("image", DataType::Binary, true),
Arc::new(BinaryArray::from_iter(vec![
Some(b"present".as_slice()),
None,
])),
);
let coerced = coerce(batch, &blob_table_schema()).await;
let image = image_struct(&coerced);
let data = image.column_by_name("data").unwrap();
assert!(!data.is_null(0));
assert!(data.is_null(1));
}
#[tokio::test]
async fn binary_coerces_into_four_child_blob_layout() {
let table_schema = Schema::new(vec![
Field::new("id", DataType::Int64, false),
wide_blob_field("image"),
]);
let batch = batch_with_image(
Field::new("image", DataType::LargeBinary, true),
Arc::new(LargeBinaryArray::from_iter(vec![
Some(b"alpha".as_slice()),
None,
])),
);
let coerced = coerce(batch, &table_schema).await;
let image = image_struct(&coerced);
assert_eq!(
image.num_columns(),
4,
"coerced struct keeps the declared layout"
);
assert!(image.column_by_name("position").unwrap().is_null(0));
assert!(image.column_by_name("size").unwrap().is_null(0));
assert!(!image.column_by_name("data").unwrap().is_null(0));
assert!(image.column_by_name("data").unwrap().is_null(1));
}
#[tokio::test]
async fn prebuilt_struct_gains_blob_field_metadata() {
let DataType::Struct(children) = blob("image", true).data_type().clone() else {
unreachable!("blob field is a struct")
};
let prebuilt = StructArray::new(
children,
vec![
Arc::new(LargeBinaryArray::from_iter_values([b"prebuilt".as_slice()])),
Arc::new(StringArray::from(vec![None::<&str>])),
],
None,
);
let batch = batch_with_image(
Field::new("image", prebuilt.data_type().clone(), true),
Arc::new(prebuilt),
);
let coerced = coerce(batch, &blob_table_schema()).await;
assert!(
coerced
.schema()
.field_with_name("image")
.unwrap()
.is_blob_v2()
);
}
#[tokio::test]
async fn prebuilt_narrow_struct_widens_to_declared_layout() {
let DataType::Struct(narrow_children) = blob("image", true).data_type().clone() else {
unreachable!("blob field is a struct")
};
let prebuilt = StructArray::new(
narrow_children,
vec![
Arc::new(LargeBinaryArray::from_iter_values([b"prebuilt".as_slice()])),
Arc::new(StringArray::from(vec![None::<&str>])),
],
None,
);
let table_schema = Schema::new(vec![
Field::new("id", DataType::Int64, false),
wide_blob_field("image"),
]);
let batch = batch_with_image(
Field::new("image", prebuilt.data_type().clone(), true),
Arc::new(prebuilt),
);
let coerced = coerce(batch, &table_schema).await;
let image = image_struct(&coerced);
assert_eq!(image.num_columns(), 4);
assert!(image.column_by_name("position").unwrap().is_null(0));
assert!(image.column_by_name("size").unwrap().is_null(0));
}
#[tokio::test]
async fn external_reference_struct_preserves_uri_position_and_size() {
let prebuilt = StructArray::new(
vec![
Field::new("data", DataType::LargeBinary, true),
Field::new("uri", DataType::Utf8, true),
Field::new("position", DataType::UInt64, true),
Field::new("size", DataType::UInt64, true),
]
.into(),
vec![
Arc::new(LargeBinaryArray::from(vec![None::<&[u8]>])) as ArrayRef,
Arc::new(StringArray::from(vec![Some("s3://bucket/blob.bin")])) as ArrayRef,
Arc::new(UInt64Array::from(vec![Some(7)])) as ArrayRef,
Arc::new(UInt64Array::from(vec![Some(6)])) as ArrayRef,
],
None,
);
let table_schema = Schema::new(vec![
Field::new("id", DataType::Int64, false),
wide_blob_field("image"),
]);
let batch = batch_with_image(
Field::new("image", prebuilt.data_type().clone(), true),
Arc::new(prebuilt),
);
let coerced = coerce(batch, &table_schema).await;
let image = image_struct(&coerced);
let uri: &StringArray = image
.column_by_name("uri")
.unwrap()
.as_any()
.downcast_ref()
.unwrap();
assert_eq!(uri.value(0), "s3://bucket/blob.bin");
let position: &UInt64Array = image
.column_by_name("position")
.unwrap()
.as_any()
.downcast_ref()
.unwrap();
assert_eq!(position.value(0), 7);
let size: &UInt64Array = image
.column_by_name("size")
.unwrap()
.as_any()
.downcast_ref()
.unwrap();
assert_eq!(size.value(0), 6);
assert!(image.column_by_name("data").unwrap().is_null(0));
}
#[tokio::test]
async fn descriptor_struct_without_value_child_is_rejected() {
let descriptor = StructArray::new(
vec![
Field::new("kind", DataType::UInt8, false),
Field::new("position", DataType::UInt64, false),
Field::new("size", DataType::UInt64, false),
]
.into(),
vec![
Arc::new(UInt8Array::from(vec![0])),
Arc::new(UInt64Array::from(vec![0])),
Arc::new(UInt64Array::from(vec![0])),
],
None,
);
let batch = batch_with_image(
Field::new("image", descriptor.data_type().clone(), true),
Arc::new(descriptor),
);
let err = coerce_err(batch, &blob_table_schema()).await;
assert!(err.to_string().contains("'data' or 'uri'"));
assert!(err.to_string().contains("image"));
}
#[tokio::test]
async fn unsupported_input_type_is_rejected_with_column_name() {
let batch = batch_with_image(
Field::new("image", DataType::Utf8, true),
Arc::new(StringArray::from(vec!["not bytes"])),
);
let err = coerce_err(batch, &blob_table_schema()).await;
assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}");
assert!(err.to_string().contains("image"));
}
#[tokio::test]
async fn blob_metadata_survives_cast_of_sibling_column() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("image", DataType::LargeBinary, true),
])),
vec![
Arc::new(Int32Array::from(vec![1])),
Arc::new(LargeBinaryArray::from_iter_values([b"x".as_slice()])),
],
)
.unwrap();
let coerced = coerce(batch, &blob_table_schema()).await;
let image_field = coerced.schema().field_with_name("image").unwrap().clone();
assert!(
image_field.is_blob_v2(),
"expected blob marker on image field, got {:?}",
image_field.metadata()
);
assert_eq!(
coerced.schema().field_with_name("id").unwrap().data_type(),
&DataType::Int64
);
}
#[tokio::test]
async fn exact_blob_input_passes_through_unchanged() {
let DataType::Struct(children) = blob("image", true).data_type().clone() else {
unreachable!("blob field is a struct")
};
let image = StructArray::new(
children,
vec![
Arc::new(LargeBinaryArray::from_iter_values([b"exact".as_slice()])),
Arc::new(StringArray::from(vec![None::<&str>])),
],
None,
);
let batch = batch_with_image(blob("image", true), Arc::new(image));
let table_schema = blob_table_schema();
let input = plan_from_batch(batch).await;
let input_ptr = Arc::as_ptr(&input);
let plan = cast_to_table_schema(input, &table_schema).unwrap();
assert_eq!(Arc::as_ptr(&plan), input_ptr, "no projection inserted");
}
}

View File

@@ -13,10 +13,8 @@ use datafusion_physical_expr::expressions::{CastExpr, Literal};
use datafusion_physical_plan::expressions::Column;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
use lance_arrow::FieldExt;
use lance_arrow::json::{is_arrow_json_field, is_json_field};
use super::blob_coerce::coerce_blob_expr;
use crate::{Error, Result};
pub fn cast_to_table_schema(
@@ -79,17 +77,6 @@ fn build_field_exprs(
continue;
}
// Blob columns accept raw binary on write; exact matches pass through below.
if table_field.is_blob_v2() && input_field.as_ref() != table_field.as_ref() {
result.push(coerce_blob_expr(
input_expr,
input_field,
table_field,
&config,
)?);
continue;
}
let expr = match (input_field.data_type(), table_field.data_type()) {
// Both are structs: recurse into sub-fields to handle subschemas and casts.
(DataType::Struct(in_children), DataType::Struct(tbl_children))

View File

@@ -4,7 +4,6 @@
//! DataFusion ExecutionPlan for inserting data into LanceDB tables.
use std::any::Any;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, LazyLock, Mutex};
use arrow_array::{RecordBatch, UInt64Array};
@@ -21,12 +20,11 @@ use datafusion_physical_plan::{
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::transaction::{Operation, Transaction};
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams, WriteProgressFn};
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter;
use lance_table::format::Fragment;
use crate::table::dataset::DatasetConsistencyWrapper;
use crate::table::write_progress::WriteProgressTracker;
pub(crate) static COUNT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
Arc::new(ArrowSchema::new(vec![Field::new(
@@ -83,7 +81,6 @@ pub struct InsertExec {
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
tracker: Option<Arc<WriteProgressTracker>>,
properties: Arc<PlanProperties>,
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
metrics: ExecutionPlanMetricsSet,
@@ -95,16 +92,6 @@ impl InsertExec {
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
) -> Self {
Self::new_with_tracker(ds_wrapper, dataset, input, write_params, None)
}
pub(crate) fn new_with_tracker(
ds_wrapper: DatasetConsistencyWrapper,
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
tracker: Option<Arc<WriteProgressTracker>>,
) -> Self {
let schema = COUNT_SCHEMA.clone();
let num_partitions = input.output_partitioning().partition_count();
@@ -120,7 +107,6 @@ impl InsertExec {
dataset,
input,
write_params,
tracker,
properties: Arc::new(properties),
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
metrics: ExecutionPlanMetricsSet::new(),
@@ -175,12 +161,11 @@ impl ExecutionPlan for InsertExec {
"InsertExec requires exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new_with_tracker(
Ok(Arc::new(Self::new(
self.ds_wrapper.clone(),
self.dataset.clone(),
children[0].clone(),
self.write_params.clone(),
self.tracker.clone(),
)))
}
@@ -191,11 +176,10 @@ impl ExecutionPlan for InsertExec {
) -> DataFusionResult<SendableRecordBatchStream> {
let input_stream = self.input.execute(partition, context)?;
let dataset = self.dataset.clone();
let mut write_params = self.write_params.clone();
let write_params = self.write_params.clone();
let partial_transactions = self.partial_transactions.clone();
let total_partitions = self.input.output_partitioning().partition_count();
let ds_wrapper = self.ds_wrapper.clone();
let tracker = self.tracker.clone();
let output_bytes = MetricBuilder::new(&self.metrics).output_bytes(partition);
let input_schema = input_stream.schema();
@@ -211,20 +195,6 @@ impl ExecutionPlan for InsertExec {
));
let stream = futures::stream::once(async move {
if let Some(tracker) = tracker
&& write_params.write_progress.is_none()
{
let last_bytes = Arc::new(AtomicU64::new(0));
write_params.write_progress = Some(WriteProgressFn::new(move |stats| {
let previous = last_bytes.swap(stats.bytes_written, Ordering::Relaxed);
if stats.bytes_written > previous {
let delta =
usize::try_from(stats.bytes_written - previous).unwrap_or(usize::MAX);
tracker.record_bytes(delta);
}
}));
}
let transaction = InsertBuilder::new(dataset.clone())
.with_params(&write_params)
.execute_uncommitted_stream(input_stream)

View File

@@ -518,10 +518,6 @@ mod tests {
let wrapper = DatasetConsistencyWrapper::new_latest(ds, Some(Duration::from_millis(200)));
// Freeze `cached_at` on the mock clock so a slow external write below can't
// expire the TTL before the explicit advance_by() does (flake on loaded CI).
clock::pin();
// Populate the cache
let v1 = wrapper.get().await.unwrap().version().version;
assert_eq!(v1, 1);

View File

@@ -579,45 +579,24 @@ fn array_to_f32_vec(arr: &Arc<dyn arrow_array::Array>) -> Result<Vec<f32>> {
})
}
/// Magic bytes that prefix (and suffix) the Arrow IPC *file* format.
const ARROW_IPC_FILE_MAGIC: &[u8] = b"ARROW1";
/// Parse Arrow IPC response from the namespace server.
///
/// The server may return either the Arrow IPC *file* format or the *stream*
/// format. REST/phalanx returns the file format (it begins with the `ARROW1`
/// magic); reading that with a `StreamReader` fails with "failed to fill whole
/// buffer". Detect the magic and pick the matching reader so both are handled.
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
use arrow_ipc::reader::{FileReader, StreamReader};
use arrow_ipc::reader::StreamReader;
use std::io::Cursor;
let (schema, batches) = if bytes.starts_with(ARROW_IPC_FILE_MAGIC) {
let reader = FileReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC file response: {}", e),
let cursor = Cursor::new(bytes);
let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC response: {}", e),
})?;
// Collect all record batches
let schema = reader.schema();
let batches: Vec<_> = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC batches: {}", e),
})?;
let schema = reader.schema();
let batches = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC file batches: {}", e),
})?;
(schema, batches)
} else {
let reader =
StreamReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
message: format!("Failed to parse Arrow IPC response: {}", e),
})?;
let schema = reader.schema();
let batches = reader
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Runtime {
message: format!("Failed to read Arrow IPC batches: {}", e),
})?;
(schema, batches)
};
// Create a stream from the batches
let stream = futures::stream::iter(batches.into_iter().map(Ok));
@@ -645,59 +624,6 @@ mod tests {
FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap()
}
#[tokio::test]
async fn test_parse_arrow_ipc_response_handles_file_and_stream() {
use arrow_array::{Int32Array, RecordBatch};
use arrow_ipc::writer::{FileWriter, StreamWriter};
use arrow_schema::{DataType, Field, Schema};
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
)
.unwrap();
// Arrow IPC *file* format -- what REST/phalanx returns. Previously this
// failed with "failed to fill whole buffer" because we used a StreamReader.
let mut file_buf = Vec::new();
{
let mut writer = FileWriter::try_new(&mut file_buf, &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
}
assert!(file_buf.starts_with(ARROW_IPC_FILE_MAGIC));
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(file_buf))
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap()
.iter()
.map(|b| b.num_rows())
.sum();
assert_eq!(rows, 3);
// Arrow IPC *stream* format must still parse.
let mut stream_buf = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut stream_buf, &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
}
assert!(!stream_buf.starts_with(ARROW_IPC_FILE_MAGIC));
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(stream_buf))
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap()
.iter()
.map(|b| b.num_rows())
.sum();
assert_eq!(rows, 3);
}
#[test]
fn test_convert_to_namespace_query_vector() {
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));

View File

@@ -142,21 +142,11 @@ impl WriteProgressTracker {
cb(&progress);
}
/// Record wire bytes from the insert layer.
///
/// These bytes may be IPC-encoded bytes for remote writes or bytes handed
/// to Lance's local writer. When wire bytes are recorded, they take
/// precedence over the in-memory Arrow bytes tracked by [`record_batch`].
/// Record wire bytes from the insert layer (e.g. IPC-encoded bytes for
/// remote writes). When wire bytes are recorded, they take precedence over
/// the in-memory Arrow bytes tracked by [`record_batch`].
pub fn record_bytes(&self, bytes: usize) {
self.wire_bytes.fetch_add(bytes, Ordering::Relaxed);
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
let guard = self
.rows_and_bytes
.lock()
.unwrap_or_else(|e| e.into_inner());
let progress = self.snapshot(guard.0, guard.1, false);
drop(guard);
cb(&progress);
}
/// Emit the final progress callback indicating the write is complete.
@@ -179,6 +169,8 @@ impl WriteProgressTracker {
let wire = self.wire_bytes.load(Ordering::Relaxed);
// Prefer wire bytes (actual I/O size) when the insert layer is
// tracking them; fall back to in-memory Arrow size otherwise.
// TODO: for local writes, track actual bytes written by Lance
// instead of using in-memory Arrow size as a proxy.
let output_bytes = if wire > 0 { wire } else { in_memory_bytes };
WriteProgress {
elapsed: self.start.elapsed(),
@@ -391,54 +383,6 @@ mod tests {
}
}
#[tokio::test]
async fn test_progress_uses_lance_write_bytes_for_local_tables() {
let dir = tempfile::tempdir().unwrap();
let db = connect(dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
let table = db
.create_table("local_write_bytes", batch)
.execute()
.await
.unwrap();
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
let in_memory_bytes = new_data.get_array_memory_size();
let final_bytes = Arc::new(AtomicUsize::new(0));
let seen_non_memory_bytes = Arc::new(std::sync::atomic::AtomicBool::new(false));
let final_bytes_cb = final_bytes.clone();
let seen_non_memory_bytes_cb = seen_non_memory_bytes.clone();
table
.add(new_data)
.write_parallelism(1)
.progress(move |p| {
if p.output_bytes() > 0 && p.output_bytes() != in_memory_bytes {
seen_non_memory_bytes_cb.store(true, Ordering::SeqCst);
}
if p.done() {
final_bytes_cb.store(p.output_bytes(), Ordering::SeqCst);
}
})
.execute()
.await
.unwrap();
assert!(
seen_non_memory_bytes.load(Ordering::SeqCst),
"progress should report Lance writer bytes, not only Arrow memory bytes"
);
assert_ne!(
final_bytes.load(Ordering::SeqCst),
in_memory_bytes,
"final progress bytes should come from Lance write stats"
);
}
#[test]
fn test_record_batch_recovers_from_poisoned_callback_lock() {
use super::{ProgressCallback, WriteProgressTracker};

View File

@@ -329,15 +329,6 @@ pub mod clock {
});
}
/// Start mock time at the current instant if not already pinned.
pub fn pin() {
MOCK_NOW.with(|mock| {
if mock.get().is_none() {
mock.set(Some(Instant::now()));
}
});
}
#[allow(dead_code)]
pub fn clear_mock() {
MOCK_NOW.with(|mock| mock.set(None));

View File

@@ -1,949 +0,0 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use arrow_array::{
Array, ArrayRef, BinaryArray, Int64Array, LargeBinaryArray, RecordBatch, StringArray,
StructArray, UInt64Array,
};
use arrow_schema::{DataType, Field, Fields, Schema};
use futures::TryStreamExt;
use lance_encoding::version::LanceFileVersion;
use lancedb::{
Connection, Error, Result, Table,
blob::blob,
connect, connect_namespace,
database::listing::OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS,
query::{ExecutableQuery, QueryBase},
table::{AddDataMode, CompactionOptions, OptimizeAction},
};
use tempfile::tempdir;
fn blob_table_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
blob("image", true),
]))
}
fn binary_input_batch(ids: &[i64], payloads: &[Option<&[u8]>]) -> RecordBatch {
RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("image", DataType::LargeBinary, true),
])),
vec![
Arc::new(Int64Array::from(ids.to_vec())),
Arc::new(LargeBinaryArray::from_iter(payloads.iter().copied())),
],
)
.unwrap()
}
async fn create_inline_blob_table(
db: &Connection,
name: &str,
ids: &[i64],
payloads: &[Option<&[u8]>],
) -> Result<Table> {
let table = db
.create_empty_table(name, blob_table_schema())
.execute()
.await?;
table
.add(binary_input_batch(ids, payloads))
.execute()
.await?;
Ok(table)
}
async fn storage_format_version(table: &Table) -> LanceFileVersion {
table
.as_native()
.unwrap()
.manifest()
.await
.unwrap()
.data_storage_format
.lance_file_version()
.unwrap()
.resolve()
}
async fn uses_stable_row_ids(table: &Table) -> bool {
table
.as_native()
.unwrap()
.manifest()
.await
.unwrap()
.uses_stable_row_ids()
}
async fn query_image_struct(table: &Table) -> StructArray {
let batches = table
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = arrow_select::concat::concat_batches(&batches[0].schema(), &batches).unwrap();
batch
.column_by_name("image")
.expect("image column present")
.as_any()
.downcast_ref::<StructArray>()
.expect("image column is a descriptor struct")
.clone()
}
#[tokio::test]
async fn declaring_blob_column_bumps_format_and_enables_stable_row_ids() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = db
.create_empty_table("t", blob_table_schema())
.execute()
.await?;
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
assert!(uses_stable_row_ids(&table).await);
Ok(())
}
#[tokio::test]
async fn explicit_stable_row_id_setting_wins_over_blob_default() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = db
.create_empty_table("t", blob_table_schema())
.storage_option(OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS, "false")
.execute()
.await?;
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
assert!(!uses_stable_row_ids(&table).await);
Ok(())
}
#[tokio::test]
async fn non_blob_table_keeps_default_format_and_row_id_setting() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
let table = db.create_empty_table("t", schema).execute().await?;
assert!(storage_format_version(&table).await < LanceFileVersion::V2_2);
assert!(!uses_stable_row_ids(&table).await);
Ok(())
}
#[tokio::test]
async fn creating_with_blob_data_bumps_format() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let blob_field = blob("image", true);
let DataType::Struct(children) = blob_field.data_type().clone() else {
unreachable!("blob field is a struct")
};
let image = StructArray::new(
children,
vec![
Arc::new(LargeBinaryArray::from_iter_values([b"payload".as_slice()])),
Arc::new(StringArray::from(vec![None::<&str>])),
],
None,
);
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
blob_field,
])),
vec![Arc::new(Int64Array::from(vec![1])), Arc::new(image)],
)
.unwrap();
let table = db.create_table("t", batch).execute().await?;
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
assert!(uses_stable_row_ids(&table).await);
assert_eq!(table.count_rows(None).await?, 1);
Ok(())
}
#[tokio::test]
async fn add_coerces_large_binary_into_blob_column() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table =
create_inline_blob_table(&db, "t", &[1, 2], &[Some(b"cat".as_slice()), Some(b"dog")])
.await?;
assert_eq!(table.count_rows(None).await?, 2);
let image = query_image_struct(&table).await;
assert_eq!(image.len(), 2);
let schema = table.schema().await?;
let field = schema.field_with_name("image").unwrap();
assert_eq!(
field
.metadata()
.get("ARROW:extension:name")
.map(String::as_str),
Some("lance.blob.v2")
);
Ok(())
}
#[tokio::test]
async fn add_coerces_binary_into_blob_column() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = db
.create_empty_table("t", blob_table_schema())
.execute()
.await?;
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("image", DataType::Binary, true),
])),
vec![
Arc::new(Int64Array::from(vec![1])),
Arc::new(BinaryArray::from_iter_values([b"small".as_slice()])),
],
)
.unwrap();
table.add(batch).execute().await?;
assert_eq!(table.count_rows(None).await?, 1);
Ok(())
}
#[tokio::test]
async fn add_accepts_null_blob_rows() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(
&db,
"t",
&[1, 2, 3],
&[Some(b"first".as_slice()), None, Some(b"third")],
)
.await?;
assert_eq!(table.count_rows(None).await?, 3);
let image = query_image_struct(&table).await;
assert_eq!(image.len(), 3);
Ok(())
}
#[tokio::test]
async fn add_rejects_uncoercible_blob_input() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = db
.create_empty_table("t", blob_table_schema())
.execute()
.await?;
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("image", DataType::Utf8, true),
])),
vec![
Arc::new(Int64Array::from(vec![1])),
Arc::new(StringArray::from(vec!["not bytes"])),
],
)
.unwrap();
let err = table.add(batch).execute().await.unwrap_err();
assert!(err.to_string().contains("image"));
Ok(())
}
#[tokio::test]
async fn connection_level_stable_row_id_setting_wins_over_blob_default() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap())
.storage_option(OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS, "false")
.execute()
.await?;
let table = db
.create_empty_table("t", blob_table_schema())
.execute()
.await?;
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
assert!(!uses_stable_row_ids(&table).await);
Ok(())
}
#[tokio::test]
async fn namespace_create_applies_blob_defaults() -> Result<()> {
let tmp = tempdir().unwrap();
let mut properties = std::collections::HashMap::new();
properties.insert("root".to_string(), tmp.path().to_str().unwrap().to_string());
let db = connect_namespace("dir", properties).execute().await?;
let table = db
.create_empty_table("t", blob_table_schema())
.execute()
.await?;
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
assert!(uses_stable_row_ids(&table).await);
Ok(())
}
// Overwrite takes the input schema as-is. A raw-binary overwrite drops the blob
// marker; re-declaring blob v2 in the input restores it.
#[tokio::test]
async fn overwrite_replaces_blob_schema_with_input_schema() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"blob".as_slice())]).await?;
let raw_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("image", DataType::LargeBinary, true),
]));
let raw_batch = RecordBatch::try_new(
raw_schema.clone(),
vec![
Arc::new(Int64Array::from(vec![2])),
Arc::new(LargeBinaryArray::from_iter_values([b"plain".as_slice()])),
],
)
.unwrap();
table
.add(raw_batch)
.mode(AddDataMode::Overwrite)
.execute()
.await?;
let schema = table.schema().await?;
assert_eq!(schema, raw_schema);
assert!(
!schema
.field_with_name("image")
.unwrap()
.metadata()
.contains_key("ARROW:extension:name")
);
let blob_field = blob("image", true);
let DataType::Struct(children) = blob_field.data_type().clone() else {
unreachable!("blob field is a struct")
};
let image = StructArray::new(
children,
vec![
Arc::new(LargeBinaryArray::from_iter_values([b"declared".as_slice()])),
Arc::new(StringArray::from(vec![None::<&str>])),
],
None,
);
let declared_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
blob_field,
])),
vec![Arc::new(Int64Array::from(vec![3])), Arc::new(image)],
)
.unwrap();
table
.add(declared_batch)
.mode(AddDataMode::Overwrite)
.execute()
.await?;
let schema = table.schema().await?;
assert_eq!(
schema
.field_with_name("image")
.unwrap()
.metadata()
.get("ARROW:extension:name")
.map(String::as_str),
Some("lance.blob.v2")
);
Ok(())
}
async fn collect_row_ids(table: &Table) -> Result<Vec<u64>> {
let batches = table
.query()
.with_row_id()
.execute()
.await?
.try_collect::<Vec<_>>()
.await?;
let batch = arrow_select::concat::concat_batches(&batches[0].schema(), &batches).unwrap();
Ok(batch
.column_by_name("_rowid")
.unwrap()
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap()
.values()
.to_vec())
}
async fn collect_id_rowid(table: &Table) -> Result<Vec<(i64, u64)>> {
let batches = table
.query()
.with_row_id()
.execute()
.await?
.try_collect::<Vec<_>>()
.await?;
let batch = arrow_select::concat::concat_batches(&batches[0].schema(), &batches).unwrap();
let ids = batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
let row_ids = batch
.column_by_name("_rowid")
.unwrap()
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
Ok(ids
.values()
.iter()
.copied()
.zip(row_ids.values().iter().copied())
.collect())
}
#[tokio::test]
async fn fetch_blobs_round_trips_bytes() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let payload: &[u8] = b"blob-round-trip-payload";
let table = create_inline_blob_table(&db, "t", &[1], &[Some(payload)]).await?;
let ids = collect_row_ids(&table).await?;
let bytes = table.fetch_blobs("image", &ids).await?;
assert_eq!(bytes.len(), 1);
assert_eq!(bytes.value(0), payload);
Ok(())
}
#[tokio::test]
async fn fetch_blobs_round_trips_nested_blob_column() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let blob_field = blob("blob", true);
let DataType::Struct(blob_children) = blob_field.data_type().clone() else {
unreachable!("blob field is a struct")
};
let blob_array = StructArray::new(
blob_children,
vec![
Arc::new(LargeBinaryArray::from_iter_values([
b"hello".as_slice(),
b"world".as_slice(),
])) as ArrayRef,
Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])) as ArrayRef,
],
None,
);
let info_fields: Fields = vec![Field::new("name", DataType::Utf8, false), blob_field].into();
let info_array = StructArray::new(
info_fields.clone(),
vec![
Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef,
Arc::new(blob_array) as ArrayRef,
],
None,
);
let schema = Arc::new(Schema::new(vec![Field::new(
"info",
DataType::Struct(info_fields),
true,
)]));
let batch = RecordBatch::try_new(schema, vec![Arc::new(info_array) as ArrayRef]).unwrap();
let table = db.create_table("t", batch).execute().await?;
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
assert!(uses_stable_row_ids(&table).await);
let ids = collect_row_ids(&table).await?;
let bytes = table.fetch_blobs("info.blob", &ids).await?;
assert_eq!(bytes.len(), 2);
let values: std::collections::HashSet<&[u8]> =
(0..bytes.len()).map(|i| bytes.value(i)).collect();
assert!(values.contains(b"hello".as_slice()));
assert!(values.contains(b"world".as_slice()));
Ok(())
}
#[tokio::test]
async fn blob_columns_lists_nested_dotted_paths() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let blob_field = blob("blob", true);
let info = Field::new(
"info",
DataType::Struct(vec![Field::new("name", DataType::Utf8, false), blob_field].into()),
true,
);
let schema = Arc::new(Schema::new(vec![
blob("thumbnail", true),
Field::new("id", DataType::Int64, false),
info,
]));
let table = db.create_empty_table("t", schema).execute().await?;
assert_eq!(table.blob_columns().await?, vec!["thumbnail", "info.blob"]);
Ok(())
}
#[tokio::test]
async fn blob_columns_lists_blob_fields_in_order() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let schema = Arc::new(Schema::new(vec![
blob("thumbnail", true),
Field::new("id", DataType::Int64, false),
blob("image", true),
]));
let table = db.create_empty_table("t", schema).execute().await?;
assert_eq!(table.blob_columns().await?, vec!["thumbnail", "image"]);
let plain = db
.create_empty_table(
"plain",
Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])),
)
.execute()
.await?;
assert!(plain.blob_columns().await?.is_empty());
Ok(())
}
#[tokio::test]
async fn fetch_blobs_preserves_null_alignment() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(
&db,
"t",
&[1, 2, 3, 4],
&[Some(b"a".as_slice()), None, Some(b"c"), None],
)
.await?;
let pairs = collect_id_rowid(&table).await?;
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
let bytes = table.fetch_blobs("image", &ids).await?;
assert_eq!(bytes.len(), ids.len());
for (i, (id, _)) in pairs.iter().enumerate() {
match id {
1 => assert_eq!(bytes.value(i), b"a"),
2 | 4 => assert!(bytes.is_null(i)),
3 => assert_eq!(bytes.value(i), b"c"),
_ => unreachable!(),
}
}
Ok(())
}
#[tokio::test]
async fn fetch_blobs_all_null_column_returns_all_nulls() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1, 2], &[None, None]).await?;
let ids = collect_row_ids(&table).await?;
let bytes = table.fetch_blobs("image", &ids).await?;
assert_eq!(bytes.len(), 2);
assert_eq!(bytes.null_count(), 2);
let files = table.fetch_blob_files("image", &ids).await?;
assert_eq!(files.len(), 2);
assert!(files.iter().all(Option::is_none));
Ok(())
}
#[tokio::test]
async fn fetch_blobs_aligns_with_reordered_and_duplicate_ids() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(
&db,
"t",
&[1, 2, 3],
&[Some(b"one".as_slice()), Some(b"two"), Some(b"three")],
)
.await?;
let pairs = collect_id_rowid(&table).await?;
let by_id = |want: i64| pairs.iter().find(|(id, _)| *id == want).unwrap().1;
let request = vec![by_id(3), by_id(1), by_id(3), by_id(2)];
let bytes = table.fetch_blobs("image", &request).await?;
assert_eq!(bytes.len(), 4);
assert_eq!(bytes.value(0), b"three");
assert_eq!(bytes.value(1), b"one");
assert_eq!(bytes.value(2), b"three");
assert_eq!(bytes.value(3), b"two");
Ok(())
}
#[tokio::test]
async fn fetch_blobs_empty_ids_returns_empty() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
assert_eq!(table.fetch_blobs("image", &[]).await?.len(), 0);
assert!(table.fetch_blob_files("image", &[]).await?.is_empty());
Ok(())
}
#[tokio::test]
async fn fetch_blobs_out_of_range_id_errors_without_panic() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
let err = table.fetch_blobs("image", &[u64::MAX]).await.unwrap_err();
assert!(err.to_string().contains("row ids"));
Ok(())
}
#[tokio::test]
async fn fetch_blobs_rejects_non_blob_column() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
let err = table.fetch_blobs("id", &[0]).await.unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));
assert!(err.to_string().contains("'id' is not a blob column"));
let err = table.fetch_blob_files("id", &[0]).await.unwrap_err();
assert!(err.to_string().contains("'id' is not a blob column"));
Ok(())
}
#[tokio::test]
async fn fetch_blobs_rejects_unknown_column() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
let err = table.fetch_blobs("missing", &[0]).await.unwrap_err();
assert!(err.to_string().contains("no column named 'missing'"));
Ok(())
}
#[tokio::test]
async fn fetch_blobs_rejects_legacy_v1_blob_column() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let legacy = Field::new("image", DataType::LargeBinary, true).with_metadata(
std::collections::HashMap::from([("lance-encoding:blob".to_string(), "true".to_string())]),
);
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
legacy,
]));
let table = db.create_empty_table("t", schema).execute().await?;
let err = table.fetch_blobs("image", &[0]).await.unwrap_err();
assert!(err.to_string().contains("legacy blob column"));
Ok(())
}
#[tokio::test]
async fn fetch_blob_files_reads_lazily_and_aligns_nulls() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table =
create_inline_blob_table(&db, "t", &[1, 2], &[Some(b"lazy-bytes".as_slice()), None])
.await?;
let pairs = collect_id_rowid(&table).await?;
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
let files = table.fetch_blob_files("image", &ids).await?;
assert_eq!(files.len(), 2);
for ((id, _), file) in pairs.iter().zip(&files) {
match id {
1 => {
let handle = file.as_ref().unwrap();
assert_eq!(handle.read().await.unwrap().as_ref(), b"lazy-bytes");
}
2 => assert!(file.is_none()),
_ => unreachable!(),
}
}
Ok(())
}
#[tokio::test]
async fn fetch_blobs_reads_multiple_blob_columns_independently() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
blob("image", true),
blob("thumbnail", true),
]));
let table = db.create_empty_table("t", schema).execute().await?;
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("image", DataType::LargeBinary, true),
Field::new("thumbnail", DataType::LargeBinary, true),
])),
vec![
Arc::new(Int64Array::from(vec![1, 2])),
Arc::new(LargeBinaryArray::from_iter(vec![
Some(b"image-1".as_slice()),
None,
])),
Arc::new(LargeBinaryArray::from_iter(vec![
None,
Some(b"thumb-2".as_slice()),
])),
],
)
.unwrap();
table.add(batch).execute().await?;
let pairs = collect_id_rowid(&table).await?;
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
let images = table.fetch_blobs("image", &ids).await?;
let thumbs = table.fetch_blobs("thumbnail", &ids).await?;
for (i, (id, _)) in pairs.iter().enumerate() {
match id {
1 => {
assert_eq!(images.value(i), b"image-1");
assert!(thumbs.is_null(i));
}
2 => {
assert!(images.is_null(i));
assert_eq!(thumbs.value(i), b"thumb-2");
}
_ => unreachable!(),
}
}
Ok(())
}
#[tokio::test]
async fn fetch_blobs_spans_fragments() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"frag-one".as_slice())]).await?;
table
.add(binary_input_batch(&[2], &[Some(b"frag-two".as_slice())]))
.execute()
.await?;
let pairs = collect_id_rowid(&table).await?;
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
let bytes = table.fetch_blobs("image", &ids).await?;
for (i, (id, _)) in pairs.iter().enumerate() {
match id {
1 => assert_eq!(bytes.value(i), b"frag-one"),
2 => assert_eq!(bytes.value(i), b"frag-two"),
_ => unreachable!(),
}
}
Ok(())
}
#[tokio::test]
async fn fetch_blobs_packed_payload_round_trip() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let big = vec![0xAB_u8; 100 * 1024];
let small = b"small".to_vec();
let table = create_inline_blob_table(
&db,
"t",
&[1, 2],
&[Some(big.as_slice()), Some(small.as_slice())],
)
.await?;
let pairs = collect_id_rowid(&table).await?;
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
let bytes = table.fetch_blobs("image", &ids).await?;
for (i, (id, _)) in pairs.iter().enumerate() {
match id {
1 => assert_eq!(bytes.value(i), big.as_slice()),
2 => assert_eq!(bytes.value(i), small.as_slice()),
_ => unreachable!(),
}
}
Ok(())
}
#[tokio::test]
async fn fetch_blobs_after_delete() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(
&db,
"t",
&[1, 2, 3],
&[Some(b"one".as_slice()), Some(b"two"), Some(b"three")],
)
.await?;
table.delete("id = 2").await?;
let pairs = collect_id_rowid(&table).await?;
assert_eq!(pairs.len(), 2);
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
let bytes = table.fetch_blobs("image", &ids).await?;
for (i, (id, _)) in pairs.iter().enumerate() {
match id {
1 => assert_eq!(bytes.value(i), b"one"),
3 => assert_eq!(bytes.value(i), b"three"),
_ => unreachable!(),
}
}
Ok(())
}
#[tokio::test]
async fn fetch_blobs_with_precompaction_row_ids_survives_compaction() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"frag-one".as_slice())]).await?;
table
.add(binary_input_batch(&[2], &[Some(b"frag-two".as_slice())]))
.execute()
.await?;
let pairs_before = collect_id_rowid(&table).await?;
let ids_before: Vec<u64> = pairs_before.iter().map(|(_, rowid)| *rowid).collect();
table
.optimize(OptimizeAction::Compact {
options: CompactionOptions::default(),
remap_options: None,
})
.await?;
let bytes_after = table.fetch_blobs("image", &ids_before).await?;
assert_eq!(bytes_after.len(), 2);
for (i, (id, _)) in pairs_before.iter().enumerate() {
match id {
1 => assert_eq!(bytes_after.value(i), b"frag-one"),
2 => assert_eq!(bytes_after.value(i), b"frag-two"),
_ => unreachable!(),
}
}
Ok(())
}
#[tokio::test]
async fn zero_length_blob_reads_back_as_null() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"".as_slice())]).await?;
let ids = collect_row_ids(&table).await?;
let bytes = table.fetch_blobs("image", &ids).await?;
assert_eq!(bytes.len(), 1);
assert!(bytes.is_null(0));
Ok(())
}
const DEDICATED_BLOB_LEN: usize = 64 * 1024;
const SCRAMBLED_LOGICAL_IDS: [i64; 7] = [6, 3, 1, 4, 6, 2, 5];
fn dedicated_blob_bytes(tag: u8) -> Vec<u8> {
vec![tag; DEDICATED_BLOB_LEN]
}
async fn multi_fragment_dedicated_blob_table(db: &Connection) -> Result<Table> {
let rows: [(i64, Option<u8>); 6] = [
(1, Some(1)),
(2, Some(2)),
(3, None),
(4, Some(4)),
(5, None),
(6, Some(6)),
];
let mut table: Option<Table> = None;
for (logical_id, blob_tag) in rows {
let bytes = blob_tag.map(dedicated_blob_bytes);
let image = [bytes.as_deref()];
table = Some(match table {
None => create_inline_blob_table(db, "t", &[logical_id], &image).await?,
Some(t) => {
t.add(binary_input_batch(&[logical_id], &image))
.execute()
.await?;
t
}
});
}
Ok(table.unwrap())
}
async fn row_ids_for_logical(table: &Table, logical_ids: &[i64]) -> Result<Vec<u64>> {
let id_rowid = collect_id_rowid(table).await?;
Ok(logical_ids
.iter()
.map(|logical_id| {
id_rowid
.iter()
.find(|(id, _)| id == logical_id)
.map(|(_, row_id)| *row_id)
.unwrap()
})
.collect())
}
#[tokio::test]
async fn fetch_blobs_aligns_across_fragments_with_nulls_and_dups() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = multi_fragment_dedicated_blob_table(&db).await?;
let row_ids = row_ids_for_logical(&table, &SCRAMBLED_LOGICAL_IDS).await?;
let bytes = table.fetch_blobs("image", &row_ids).await?;
assert_eq!(bytes.len(), SCRAMBLED_LOGICAL_IDS.len());
for (slot, logical_id) in SCRAMBLED_LOGICAL_IDS.iter().enumerate() {
match logical_id {
3 | 5 => assert!(bytes.is_null(slot)),
id => assert_eq!(
bytes.value(slot),
dedicated_blob_bytes(*id as u8).as_slice()
),
}
}
Ok(())
}
#[tokio::test]
async fn fetch_blob_files_aligns_across_fragments_with_nulls_and_dups() -> Result<()> {
let tmp = tempdir().unwrap();
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
let table = multi_fragment_dedicated_blob_table(&db).await?;
let row_ids = row_ids_for_logical(&table, &SCRAMBLED_LOGICAL_IDS).await?;
let files = table.fetch_blob_files("image", &row_ids).await?;
assert_eq!(files.len(), SCRAMBLED_LOGICAL_IDS.len());
for (slot, logical_id) in SCRAMBLED_LOGICAL_IDS.iter().enumerate() {
match logical_id {
3 | 5 => assert!(files[slot].is_none()),
id => {
let payload = files[slot].as_ref().unwrap().read().await?;
assert_eq!(payload.as_ref(), dedicated_blob_bytes(*id as u8).as_slice());
}
}
}
Ok(())
}