mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-30 09:30:41 +00:00
Compare commits
2 Commits
jack/sopho
...
yang/appro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ff72022dd | ||
|
|
b2e0aa0588 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.31.0-beta.4"
|
||||
current_version = "0.30.1-beta.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
@@ -23,8 +23,6 @@ allow_dirty = true
|
||||
commit = true
|
||||
message = "Bump version: {current_version} → {new_version}"
|
||||
commit_args = ""
|
||||
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
|
||||
allow_shell_hooks = true
|
||||
|
||||
# Java maven files
|
||||
pre_commit_hooks = [
|
||||
|
||||
190
Cargo.lock
generated
190
Cargo.lock
generated
@@ -157,9 +157,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.103"
|
||||
version = "1.0.102"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a4385e2e34eb35d6b3efe798b9eb88096925d87726c0798709bf56d9ed84af3"
|
||||
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
@@ -1297,6 +1297,15 @@ version = "2.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3"
|
||||
|
||||
[[package]]
|
||||
name = "bitpacking"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96a7139abd3d9cebf8cd6f920a389cf3dc9576172e32f4563f188cae3c3eb019"
|
||||
dependencies = [
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitvec"
|
||||
version = "1.0.1"
|
||||
@@ -1463,9 +1472,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.12.0"
|
||||
version = "1.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ae3f5d315924270530207e2a68396c3cc547f6dca3fbdca317cfb1a51edb593"
|
||||
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
|
||||
|
||||
[[package]]
|
||||
name = "bytes-utils"
|
||||
@@ -1750,7 +1759,7 @@ version = "3.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3014,7 +3023,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3177,9 +3186,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "2.0.0"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "900d271a03799a1ee8d1ca9b19893b48ca674a9284fefcfb85f05e74ed314217"
|
||||
checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef"
|
||||
dependencies = [
|
||||
"log",
|
||||
"regex",
|
||||
@@ -3187,9 +3196,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.11"
|
||||
version = "0.11.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de671bd27a75a797dc9ae289ba1e77276e75e2026408aab65185384e2d5cd3f6"
|
||||
checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
@@ -3231,7 +3240,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3423,8 +3432,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "fsst"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"rand 0.9.4",
|
||||
@@ -4466,7 +4475,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4560,7 +4569,7 @@ dependencies = [
|
||||
"portable-atomic-util",
|
||||
"serde_core",
|
||||
"wasm-bindgen",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4726,8 +4735,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"arrow",
|
||||
@@ -4745,6 +4754,7 @@ dependencies = [
|
||||
"async_cell",
|
||||
"aws-credential-types",
|
||||
"aws-sdk-dynamodb",
|
||||
"bitpacking",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -4761,9 +4771,8 @@ dependencies = [
|
||||
"futures",
|
||||
"half",
|
||||
"humantime",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.13.0",
|
||||
"lance-arrow",
|
||||
"lance-bitpacking",
|
||||
"lance-core",
|
||||
"lance-datafusion",
|
||||
"lance-encoding",
|
||||
@@ -4801,8 +4810,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-arrow"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4823,7 +4832,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "lance-arrow-scalar"
|
||||
version = "58.0.0"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4837,7 +4846,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "lance-arrow-stats"
|
||||
version = "58.0.0"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -4846,19 +4855,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-bitpacking"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"crunchy",
|
||||
"paste",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lance-core"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4870,7 +4878,7 @@ dependencies = [
|
||||
"datafusion-common",
|
||||
"datafusion-sql",
|
||||
"futures",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.13.0",
|
||||
"lance-arrow",
|
||||
"lance-derive",
|
||||
"libc",
|
||||
@@ -4896,8 +4904,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datafusion"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4927,8 +4935,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datagen"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4945,8 +4953,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-derive"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -4955,8 +4963,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-encoding"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4972,7 +4980,7 @@ dependencies = [
|
||||
"futures",
|
||||
"hex",
|
||||
"hyperloglogplus",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.13.0",
|
||||
"lance-arrow",
|
||||
"lance-bitpacking",
|
||||
"lance-core",
|
||||
@@ -4991,8 +4999,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-file"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -5022,8 +5030,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-index"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"arrow",
|
||||
@@ -5035,6 +5043,7 @@ dependencies = [
|
||||
"async-channel",
|
||||
"async-recursion",
|
||||
"async-trait",
|
||||
"bitpacking",
|
||||
"bitvec",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -5047,12 +5056,11 @@ dependencies = [
|
||||
"fst",
|
||||
"futures",
|
||||
"half",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.13.0",
|
||||
"jieba-rs",
|
||||
"jsonb",
|
||||
"lance-arrow",
|
||||
"lance-arrow-stats",
|
||||
"lance-bitpacking",
|
||||
"lance-core",
|
||||
"lance-datafusion",
|
||||
"lance-datagen",
|
||||
@@ -5088,8 +5096,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-io"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -5130,8 +5138,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-linalg"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5142,13 +5150,12 @@ dependencies = [
|
||||
"lance-core",
|
||||
"num-traits",
|
||||
"rand 0.9.4",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5160,8 +5167,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace-impls"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-ipc",
|
||||
@@ -5215,15 +5222,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-select"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-schema",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.13.0",
|
||||
"lance-core",
|
||||
"roaring",
|
||||
"tracing",
|
||||
@@ -5231,8 +5238,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-table"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -5271,8 +5278,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-testing"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -5285,21 +5292,20 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-tokenizer"
|
||||
version = "9.0.0-beta.10"
|
||||
source = "git+https://github.com/jackye1995/lance.git?branch=jack%2Fsophon-pr-6325#1c5b5061c60934b4c18dbe86c5e91b4961105989"
|
||||
version = "8.0.0-beta.16"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.16#6e734df607f2841fe3bba82f05a90f3174933bab"
|
||||
dependencies = [
|
||||
"icu_segmenter",
|
||||
"jieba-rs",
|
||||
"lindera",
|
||||
"rust-stemmers",
|
||||
"serde",
|
||||
"stop-words",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.31.0-beta.4"
|
||||
version = "0.30.1-beta.2"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -5376,14 +5382,13 @@ dependencies = [
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"url",
|
||||
"urlencoding",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.31.0-beta.4"
|
||||
version = "0.30.1-beta.2"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5408,7 +5413,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.34.0-beta.4"
|
||||
version = "0.33.1-beta.2"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5641,9 +5646,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.33"
|
||||
version = "0.4.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ceec5bc11778974d1bcb055b18002eba7f4b3518b6a0081b3af5f21666da9ad"
|
||||
checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a"
|
||||
|
||||
[[package]]
|
||||
name = "loom"
|
||||
@@ -5951,9 +5956,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "napi"
|
||||
version = "3.9.4"
|
||||
version = "3.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b41bda2ac390efb5e8d22025d925ccc3f3807d8c1bea6d19b36127247c4b8f83"
|
||||
checksum = "ad513ff22558f1830b595ea6eb4091da48145d09a222ce157e781896f78be0b9"
|
||||
dependencies = [
|
||||
"bitflags 2.11.1",
|
||||
"chrono",
|
||||
@@ -5976,9 +5981,9 @@ checksum = "c9c366d2c8c60b86fa632df75f745509b52f9128f91a6bad4c796e44abb505e1"
|
||||
|
||||
[[package]]
|
||||
name = "napi-derive"
|
||||
version = "3.5.7"
|
||||
version = "3.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "61d66f70256ad5aef58659966064471d0ad90e2897bc36a5a5e0389c85aabc1e"
|
||||
checksum = "89b3f766e04667e6da0e181e2da4f85475d5a6513b7cf6a80bea184e224a5b42"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"ctor 1.0.5",
|
||||
@@ -5990,9 +5995,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "napi-derive-backend"
|
||||
version = "5.0.5"
|
||||
version = "5.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81b4b08f15eed7a2a20c3f4c6314013fc3ac890a3afa9892b594485299ebdb2d"
|
||||
checksum = "0d5af30503edf933ce7377cf6d4c877a62b0f1107ea05585f1b5e430e88d5baf"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"proc-macro2",
|
||||
@@ -6085,7 +6090,7 @@ version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7400,8 +7405,8 @@ version = "0.14.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"itertools 0.11.0",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.14.0",
|
||||
"log",
|
||||
"multimap",
|
||||
"petgraph",
|
||||
@@ -7420,7 +7425,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.11.0",
|
||||
"itertools 0.14.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -7654,7 +7659,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"socket2 0.6.3",
|
||||
"tracing",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8394,7 +8399,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8465,7 +8470,7 @@ dependencies = [
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"webpki-root-certs",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -9027,7 +9032,7 @@ version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -9039,7 +9044,7 @@ version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -9200,15 +9205,6 @@ version = "0.2.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978"
|
||||
|
||||
[[package]]
|
||||
name = "stop-words"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d68df56303396bcfb639455b3c166804aeb7994005010aab5e9e8a1277b8871d"
|
||||
dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "str_stack"
|
||||
version = "0.1.1"
|
||||
@@ -9472,7 +9468,7 @@ dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -10120,9 +10116,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.23.4"
|
||||
version = "1.23.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf80a72845275afea99e7f2b434723d3bc7e38470fcd1c7ed39a599c73319a53"
|
||||
checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7"
|
||||
dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"js-sys",
|
||||
@@ -10407,7 +10403,7 @@ version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -13,20 +13,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-core = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-datagen = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-file = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-io = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-index = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-linalg = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-namespace = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=9.0.0-beta.10", default-features = false, "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-table = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-testing = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-datafusion = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-encoding = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance-arrow = { "version" = "=9.0.0-beta.10", "branch" = "jack/sophon-pr-6325", "git" = "https://github.com/jackye1995/lance.git" }
|
||||
lance = { "version" = "=8.0.0-beta.16", default-features = false, "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=8.0.0-beta.16", default-features = false, "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=8.0.0-beta.16", default-features = false, "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=8.0.0-beta.16", "tag" = "v8.0.0-beta.16", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "58.0.0", optional = false }
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.31.0-beta.4</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / OAuthFlowType
|
||||
|
||||
# Enumeration: OAuthFlowType
|
||||
|
||||
OAuth authentication flow types.
|
||||
|
||||
## Enumeration Members
|
||||
|
||||
### AzureManagedIdentity
|
||||
|
||||
```ts
|
||||
AzureManagedIdentity: "azure_managed_identity";
|
||||
```
|
||||
|
||||
Azure Managed Identity via IMDS.
|
||||
|
||||
***
|
||||
|
||||
### ClientCredentials
|
||||
|
||||
```ts
|
||||
ClientCredentials: "client_credentials";
|
||||
```
|
||||
|
||||
Client Credentials grant (service-to-service / M2M).
|
||||
@@ -12,7 +12,6 @@
|
||||
## Enumerations
|
||||
|
||||
- [FullTextQueryType](enumerations/FullTextQueryType.md)
|
||||
- [OAuthFlowType](enumerations/OAuthFlowType.md)
|
||||
- [Occur](enumerations/Occur.md)
|
||||
- [Operator](enumerations/Operator.md)
|
||||
|
||||
@@ -86,8 +85,6 @@
|
||||
- [ListNamespacesResponse](interfaces/ListNamespacesResponse.md)
|
||||
- [LsmWriteSpec](interfaces/LsmWriteSpec.md)
|
||||
- [MergeResult](interfaces/MergeResult.md)
|
||||
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
|
||||
- [OAuthConfig](interfaces/OAuthConfig.md)
|
||||
- [OpenTableOptions](interfaces/OpenTableOptions.md)
|
||||
- [OptimizeOptions](interfaces/OptimizeOptions.md)
|
||||
- [OptimizeStats](interfaces/OptimizeStats.md)
|
||||
|
||||
@@ -64,19 +64,6 @@ client used by manifest-enabled native connections.
|
||||
|
||||
***
|
||||
|
||||
### oauthConfig?
|
||||
|
||||
```ts
|
||||
optional oauthConfig: NativeOAuthConfig;
|
||||
```
|
||||
|
||||
(For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
and refresh are handled entirely in Rust. TypeScript users should pass
|
||||
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
|
||||
***
|
||||
|
||||
### readConsistencyInterval?
|
||||
|
||||
```ts
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / NativeOAuthConfig
|
||||
|
||||
# Interface: NativeOAuthConfig
|
||||
|
||||
OAuth configuration for LanceDB authentication.
|
||||
|
||||
This is the generated napi-rs binding shape. TypeScript users should prefer
|
||||
the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
|
||||
All token acquisition and refresh is handled in the Rust layer.
|
||||
|
||||
## Properties
|
||||
|
||||
### clientId
|
||||
|
||||
```ts
|
||||
clientId: string;
|
||||
```
|
||||
|
||||
Application / Client ID.
|
||||
|
||||
***
|
||||
|
||||
### clientSecret?
|
||||
|
||||
```ts
|
||||
optional clientSecret: string;
|
||||
```
|
||||
|
||||
Client secret (required for client_credentials).
|
||||
|
||||
***
|
||||
|
||||
### flow?
|
||||
|
||||
```ts
|
||||
optional flow: string;
|
||||
```
|
||||
|
||||
Authentication flow: "client_credentials" or "azure_managed_identity"
|
||||
|
||||
***
|
||||
|
||||
### issuerUrl
|
||||
|
||||
```ts
|
||||
issuerUrl: string;
|
||||
```
|
||||
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
|
||||
***
|
||||
|
||||
### managedIdentityClientId?
|
||||
|
||||
```ts
|
||||
optional managedIdentityClientId: string;
|
||||
```
|
||||
|
||||
Client ID for user-assigned managed identity (azure_managed_identity).
|
||||
|
||||
***
|
||||
|
||||
### refreshBufferSecs?
|
||||
|
||||
```ts
|
||||
optional refreshBufferSecs: number;
|
||||
```
|
||||
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
***
|
||||
|
||||
### scopes
|
||||
|
||||
```ts
|
||||
scopes: string[];
|
||||
```
|
||||
|
||||
OAuth scopes to request. For Azure managed identity, exactly one scope
|
||||
or resource is required. For example: `["api://{app_id}/.default"]`
|
||||
@@ -1,111 +0,0 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / OAuthConfig
|
||||
|
||||
# Interface: OAuthConfig
|
||||
|
||||
OAuth configuration for LanceDB authentication.
|
||||
|
||||
This is the public TypeScript OAuth configuration type. The generated
|
||||
`NativeOAuthConfig` type has the same runtime shape but is an implementation
|
||||
detail of the napi-rs binding.
|
||||
|
||||
All token acquisition and refresh is handled in the Rust layer.
|
||||
This config is passed through to Rust via napi-rs.
|
||||
|
||||
## Examples
|
||||
|
||||
```typescript
|
||||
const config: OAuthConfig = {
|
||||
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
clientId: "app-id",
|
||||
clientSecret: "secret",
|
||||
scopes: ["api://lancedb-api/.default"],
|
||||
};
|
||||
```
|
||||
|
||||
```typescript
|
||||
const config: OAuthConfig = {
|
||||
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
clientId: "app-id",
|
||||
scopes: ["api://lancedb-api/.default"],
|
||||
flow: OAuthFlowType.AzureManagedIdentity,
|
||||
};
|
||||
```
|
||||
|
||||
## Properties
|
||||
|
||||
### clientId
|
||||
|
||||
```ts
|
||||
clientId: string;
|
||||
```
|
||||
|
||||
Application / Client ID.
|
||||
|
||||
***
|
||||
|
||||
### clientSecret?
|
||||
|
||||
```ts
|
||||
optional clientSecret: string;
|
||||
```
|
||||
|
||||
Client secret (required for ClientCredentials).
|
||||
|
||||
***
|
||||
|
||||
### flow?
|
||||
|
||||
```ts
|
||||
optional flow: OAuthFlowType;
|
||||
```
|
||||
|
||||
Authentication flow (default: ClientCredentials).
|
||||
|
||||
***
|
||||
|
||||
### issuerUrl
|
||||
|
||||
```ts
|
||||
issuerUrl: string;
|
||||
```
|
||||
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
|
||||
***
|
||||
|
||||
### managedIdentityClientId?
|
||||
|
||||
```ts
|
||||
optional managedIdentityClientId: string;
|
||||
```
|
||||
|
||||
Client ID for user-assigned managed identity (AzureManagedIdentity).
|
||||
|
||||
***
|
||||
|
||||
### refreshBufferSecs?
|
||||
|
||||
```ts
|
||||
optional refreshBufferSecs: number;
|
||||
```
|
||||
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
***
|
||||
|
||||
### scopes
|
||||
|
||||
```ts
|
||||
scopes: string[];
|
||||
```
|
||||
|
||||
OAuth scopes to request.
|
||||
For Azure managed identity, exactly one scope or resource is required.
|
||||
For example: `["api://{app_id}/.default"]`
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.31.0-beta.4</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.31.0-beta.4</version>
|
||||
<version>0.30.1-beta.2</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>9.0.0-beta.10</lance-core.version>
|
||||
<lance-core.version>8.0.0-beta.16</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.31.0-beta.4"
|
||||
version = "0.30.1-beta.2"
|
||||
publish = false
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
@@ -52,7 +52,6 @@ export {
|
||||
SplitHashOptions,
|
||||
SplitSequentialOptions,
|
||||
ShuffleOptions,
|
||||
OAuthConfig as NativeOAuthConfig,
|
||||
} from "./native.js";
|
||||
|
||||
export {
|
||||
@@ -131,8 +130,6 @@ export {
|
||||
TokenResponse,
|
||||
} from "./header";
|
||||
|
||||
export { OAuthConfig, OAuthFlowType } from "./oauth";
|
||||
|
||||
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
/**
|
||||
* OAuth authentication flow types.
|
||||
*/
|
||||
export enum OAuthFlowType {
|
||||
/** Client Credentials grant (service-to-service / M2M). */
|
||||
ClientCredentials = "client_credentials",
|
||||
/** Azure Managed Identity via IMDS. */
|
||||
AzureManagedIdentity = "azure_managed_identity",
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth configuration for LanceDB authentication.
|
||||
*
|
||||
* This is the public TypeScript OAuth configuration type. The generated
|
||||
* `NativeOAuthConfig` type has the same runtime shape but is an implementation
|
||||
* detail of the napi-rs binding.
|
||||
*
|
||||
* All token acquisition and refresh is handled in the Rust layer.
|
||||
* This config is passed through to Rust via napi-rs.
|
||||
*
|
||||
* @example Client Credentials (service-to-service):
|
||||
* ```typescript
|
||||
* const config: OAuthConfig = {
|
||||
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
* clientId: "app-id",
|
||||
* clientSecret: "secret",
|
||||
* scopes: ["api://lancedb-api/.default"],
|
||||
* };
|
||||
* ```
|
||||
*
|
||||
* @example Azure Managed Identity:
|
||||
* ```typescript
|
||||
* const config: OAuthConfig = {
|
||||
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
* clientId: "app-id",
|
||||
* scopes: ["api://lancedb-api/.default"],
|
||||
* flow: OAuthFlowType.AzureManagedIdentity,
|
||||
* };
|
||||
* ```
|
||||
*/
|
||||
export interface OAuthConfig {
|
||||
/**
|
||||
* OIDC issuer URL or OAuth authority URL.
|
||||
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
*/
|
||||
issuerUrl: string;
|
||||
|
||||
/** Application / Client ID. */
|
||||
clientId: string;
|
||||
|
||||
/**
|
||||
* OAuth scopes to request.
|
||||
* For Azure managed identity, exactly one scope or resource is required.
|
||||
* For example: `["api://{app_id}/.default"]`
|
||||
*/
|
||||
scopes: string[];
|
||||
|
||||
/** Authentication flow (default: ClientCredentials). */
|
||||
flow?: OAuthFlowType;
|
||||
|
||||
/** Client secret (required for ClientCredentials). */
|
||||
clientSecret?: string;
|
||||
|
||||
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
|
||||
managedIdentityClientId?: string;
|
||||
|
||||
/**
|
||||
* Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
* Keep this well below the token TTL; if it is greater than or equal to
|
||||
* the TTL, each request refreshes the token.
|
||||
*/
|
||||
refreshBufferSecs?: number;
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.31.0-beta.4",
|
||||
"version": "0.30.1-beta.2",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -112,12 +112,6 @@ impl Connection {
|
||||
|
||||
builder = builder.client_config(rust_config);
|
||||
|
||||
if let Some(oauth_config) = options.oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig =
|
||||
oauth_config.try_into().default_error()?;
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
|
||||
if let Some(api_key) = options.api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
}
|
||||
|
||||
@@ -65,11 +65,6 @@ pub struct ConnectionOptions {
|
||||
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
|
||||
/// for testing purposes.
|
||||
pub host_override: Option<String>,
|
||||
/// (For LanceDB cloud only): OAuth configuration for IdP-based
|
||||
/// authentication (e.g., Azure Entra ID). When set, token acquisition
|
||||
/// and refresh are handled entirely in Rust. TypeScript users should pass
|
||||
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
pub oauth_config: Option<remote::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use lancedb::error::Error;
|
||||
use napi_derive::*;
|
||||
|
||||
/// Timeout configuration for remote HTTP client.
|
||||
@@ -141,84 +140,6 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth configuration for LanceDB authentication.
|
||||
///
|
||||
/// This is the generated napi-rs binding shape. TypeScript users should prefer
|
||||
/// the public `OAuthConfig` type exported from `@lancedb/lancedb`.
|
||||
///
|
||||
/// All token acquisition and refresh is handled in the Rust layer.
|
||||
#[napi(object)]
|
||||
#[derive(Clone)]
|
||||
pub struct OAuthConfig {
|
||||
/// OIDC issuer URL or OAuth authority URL.
|
||||
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
pub issuer_url: String,
|
||||
/// Application / Client ID.
|
||||
pub client_id: String,
|
||||
/// OAuth scopes to request. For Azure managed identity, exactly one scope
|
||||
/// or resource is required. For example: `["api://{app_id}/.default"]`
|
||||
pub scopes: Vec<String>,
|
||||
/// Authentication flow: "client_credentials" or "azure_managed_identity"
|
||||
pub flow: Option<String>,
|
||||
/// Client secret (required for client_credentials).
|
||||
pub client_secret: Option<String>,
|
||||
/// Client ID for user-assigned managed identity (azure_managed_identity).
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
/// Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
/// Keep this well below the token TTL; if it is greater than or equal to
|
||||
/// the TTL, each request refreshes the token.
|
||||
pub refresh_buffer_secs: Option<u32>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthConfig")
|
||||
.field("issuer_url", &self.issuer_url)
|
||||
.field("client_id", &self.client_id)
|
||||
.field("scopes", &self.scopes)
|
||||
.field("flow", &self.flow)
|
||||
.field(
|
||||
"client_secret",
|
||||
&self.client_secret.as_deref().map(|_| "<redacted>"),
|
||||
)
|
||||
.field(
|
||||
"managed_identity_client_id",
|
||||
&self.managed_identity_client_id,
|
||||
)
|
||||
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(config: OAuthConfig) -> Result<Self, Self::Error> {
|
||||
use lancedb::remote::oauth::OAuthFlow;
|
||||
|
||||
let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
|
||||
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: config.managed_identity_client_id,
|
||||
},
|
||||
other => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!("Unknown OAuth flow type: {other}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
issuer_url: config.issuer_url,
|
||||
client_id: config.client_id,
|
||||
client_secret: config.client_secret,
|
||||
scopes: config.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: config.refresh_buffer_secs.map(|v| v as u64),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
@@ -235,45 +156,3 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_unknown_oauth_flow_returns_invalid_input() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: Some("typo".to_string()),
|
||||
client_secret: None,
|
||||
managed_identity_client_id: None,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let err = lancedb::remote::oauth::OAuthConfig::try_from(config).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::InvalidInput { message }
|
||||
if message == "Unknown OAuth flow type: typo"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_config_debug_redacts_client_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: Some("client_credentials".to_string()),
|
||||
client_secret: Some("super-secret".to_string()),
|
||||
managed_identity_client_id: None,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let debug = format!("{config:?}");
|
||||
assert!(!debug.contains("super-secret"));
|
||||
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.34.0-beta.4"
|
||||
current_version = "0.33.1-beta.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
@@ -23,8 +23,6 @@ allow_dirty = true
|
||||
commit = true
|
||||
message = "Bump version: {current_version} → {new_version}"
|
||||
commit_args = ""
|
||||
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
|
||||
allow_shell_hooks = true
|
||||
|
||||
# Update Cargo.lock after version bump
|
||||
pre_commit_hooks = [
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.34.0-beta.4"
|
||||
version = "0.33.1-beta.2"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
|
||||
@@ -17,17 +17,6 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .expr import Expr, col, lit, func
|
||||
from .udf import (
|
||||
udf,
|
||||
table_udf,
|
||||
Udf,
|
||||
JobHandle,
|
||||
JobFailedError,
|
||||
MaterializedView,
|
||||
AsyncJobHandle,
|
||||
AsyncMaterializedView,
|
||||
)
|
||||
from .lineage import Lineage, Node, Edge, FunctionRef
|
||||
from .schema import vector
|
||||
from .table import AsyncTable, Table
|
||||
from ._lancedb import Session
|
||||
@@ -100,8 +89,6 @@ def connect(
|
||||
If presented, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||
OAuth configuration is currently supported only by ``connect_async``;
|
||||
synchronous LanceDB Cloud connections require an API key.
|
||||
region: str, default "us-east-1"
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
@@ -353,7 +340,6 @@ async def connect_async(
|
||||
session: Optional[Session] = None,
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config=None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -403,10 +389,6 @@ async def connect_async(
|
||||
namespace_client_properties : dict, optional
|
||||
Additional directory namespace client properties to use with
|
||||
``manifest_enabled=True``.
|
||||
oauth_config : OAuthConfig, optional
|
||||
OAuth configuration for LanceDB Cloud/Enterprise. This is supported by
|
||||
``connect_async`` only; synchronous ``connect`` uses API key
|
||||
authentication for ``db://`` URIs.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -453,24 +435,11 @@ async def connect_async(
|
||||
session,
|
||||
manifest_enabled,
|
||||
namespace_client_properties,
|
||||
oauth_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"udf",
|
||||
"table_udf",
|
||||
"Udf",
|
||||
"JobHandle",
|
||||
"JobFailedError",
|
||||
"MaterializedView",
|
||||
"AsyncJobHandle",
|
||||
"AsyncMaterializedView",
|
||||
"Lineage",
|
||||
"Node",
|
||||
"Edge",
|
||||
"FunctionRef",
|
||||
"connect",
|
||||
"connect_async",
|
||||
"connect_namespace",
|
||||
|
||||
@@ -280,7 +280,6 @@ async def connect(
|
||||
session: Optional[Session],
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
oauth_config: Optional[Any] = None,
|
||||
) -> Connection: ...
|
||||
|
||||
class RecordBatchStream:
|
||||
|
||||
@@ -65,7 +65,6 @@ if TYPE_CHECKING:
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from ._lancedb import Session
|
||||
from .udf import MaterializedView, AsyncMaterializedView
|
||||
|
||||
from .namespace_utils import (
|
||||
_normalize_create_namespace_mode,
|
||||
@@ -563,259 +562,6 @@ class DBConnection(EnforceOverrides):
|
||||
"""
|
||||
raise NotImplementedError("serialize is not supported for this connection type")
|
||||
|
||||
# -- Derived compute: functions, materialized views, jobs -------------
|
||||
# Server-backed features (LanceDB Enterprise / Cloud); local
|
||||
# connections raise NotImplementedError for now.
|
||||
|
||||
def create_function(
|
||||
self,
|
||||
name,
|
||||
language: str = "python",
|
||||
return_type: Optional[str] = None,
|
||||
body: Optional[str] = None,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Register a UDF (CREATE FUNCTION).
|
||||
|
||||
Pass a ``@udf`` / ``@table_udf``-decorated function (preferred):
|
||||
|
||||
db.create_function(embed)
|
||||
|
||||
or the explicit fields:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str or Udf
|
||||
A decorated UDF object, or the function name.
|
||||
language: str
|
||||
Implementation language (currently "python").
|
||||
return_type: str
|
||||
SQL return type, e.g. "FLOAT", "FLOAT[1536]",
|
||||
"STRUCT(a FLOAT, b VARCHAR)", "TABLE(chunk VARCHAR, idx INT)".
|
||||
body: str
|
||||
Function body: source text, or base64 cloudpickle bytes when
|
||||
options["body_format"] == "cloudpickle".
|
||||
options: dict, optional
|
||||
input_columns, pip, num_gpus, batch_size, timeout,
|
||||
error_policy, docker_image, body_format, ...
|
||||
replace: bool
|
||||
Drop an existing function of the same name first.
|
||||
"""
|
||||
from .udf import Udf
|
||||
|
||||
if isinstance(name, Udf):
|
||||
req = name.create_request()
|
||||
name, language, return_type, body, options = (
|
||||
req["name"],
|
||||
req["language"],
|
||||
req["return_type"],
|
||||
req["body"],
|
||||
req["options"],
|
||||
)
|
||||
if replace:
|
||||
try:
|
||||
self.drop_function(name)
|
||||
except Exception:
|
||||
pass
|
||||
LOOP.run(self._conn.create_function(name, language, return_type, body, options))
|
||||
|
||||
def list_functions(self):
|
||||
"""List registered functions (SHOW FUNCTIONS)."""
|
||||
return LOOP.run(self._conn.list_functions())
|
||||
|
||||
def drop_function(self, name: str):
|
||||
"""Drop a registered function (DROP FUNCTION)."""
|
||||
LOOP.run(self._conn.drop_function(name))
|
||||
|
||||
def create_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
source=None,
|
||||
select=None,
|
||||
*,
|
||||
query: Optional[str] = None,
|
||||
where: Optional[str] = None,
|
||||
auto_refresh: bool = False,
|
||||
with_no_data: bool = False,
|
||||
replace: bool = False,
|
||||
partition_by: Optional[str] = None,
|
||||
) -> "MaterializedView":
|
||||
"""Create a materialized view (CREATE MATERIALIZED VIEW); returns a
|
||||
`MaterializedView` handle (``.wait()`` blocks until it is populated).
|
||||
|
||||
Two ways to specify the view body:
|
||||
|
||||
- ergonomic: pass ``source`` (a table name or table) and ``select``
|
||||
items -- column names, expression strings ("embed(body)"),
|
||||
(alias, expression) tuples, or ``@udf`` / ``@table_udf`` objects.
|
||||
The SELECT is assembled and parsed server-side (one parser, shared
|
||||
with SQL).
|
||||
- raw: pass ``query=`` with a full SELECT, e.g.
|
||||
"SELECT id, embed(body) AS vec FROM articles WHERE id > 1".
|
||||
|
||||
`partition_by` partitions the view's (single) table function on a source
|
||||
column. If that column has an IVF vector index the server partitions by
|
||||
its index clusters (image-dedup style); otherwise it groups by distinct
|
||||
value. (Geneva's `partition_by` and `partition_by_indexed_column` unify
|
||||
here -- the engine picks the strategy from the column.)
|
||||
"""
|
||||
from .udf import build_view_query, MaterializedView
|
||||
|
||||
if query is None:
|
||||
if source is None or select is None:
|
||||
raise ValueError(
|
||||
"create_materialized_view needs either query= or both "
|
||||
"source and select"
|
||||
)
|
||||
query = build_view_query(source, select)
|
||||
if where:
|
||||
query += f" WHERE {where}"
|
||||
if replace:
|
||||
self._drop_view_if_exists(name)
|
||||
job_id = LOOP.run(
|
||||
self._conn.create_materialized_view(
|
||||
name,
|
||||
query=query,
|
||||
auto_refresh=auto_refresh,
|
||||
with_no_data=with_no_data,
|
||||
partition_by=partition_by,
|
||||
)
|
||||
)
|
||||
return MaterializedView(self, name, job_id=job_id)
|
||||
|
||||
def _drop_view_if_exists(self, name: str) -> None:
|
||||
# `replace=True` is "drop if present"; only a not-found error is
|
||||
# benign here. Anything else (perms, server fault) must surface rather
|
||||
# than be masked by a later create failure.
|
||||
try:
|
||||
self.drop_materialized_view(name)
|
||||
except Exception as e:
|
||||
msg = str(e).lower()
|
||||
if "not found" not in msg and "does not exist" not in msg:
|
||||
raise
|
||||
|
||||
def job(self, job_id: str):
|
||||
"""A `JobHandle` for reconnecting to an inflight job by id -- e.g. an
|
||||
id you stored, or one returned from the SQL / REST surface. Submit
|
||||
methods (`refresh_column`, `MaterializedView.refresh`) already return a
|
||||
handle directly, so you do not need this to wait on a fresh submission."""
|
||||
from .udf import JobHandle
|
||||
|
||||
return JobHandle(self, job_id)
|
||||
|
||||
def lineage(
|
||||
self,
|
||||
table: str,
|
||||
column: Optional[str] = None,
|
||||
*,
|
||||
direction: Optional[str] = None,
|
||||
depth: Optional[int] = None,
|
||||
):
|
||||
"""Derived-compute lineage of a table/view, or one of its columns:
|
||||
upstream sources, downstream dependents, and the function version +
|
||||
location that produced each derived column (with a drift flag). Returns
|
||||
a `Lineage`. `direction` is "upstream" | "downstream" | "both" (server
|
||||
default both); `depth` limits column-hops (transitive when omitted)."""
|
||||
# `self._conn` is the AsyncConnection; drive its async `lineage`
|
||||
# (which parses the JSON) on the loop, mirroring create_materialized_view.
|
||||
return LOOP.run(
|
||||
self._conn.lineage(table, column, direction=direction, depth=depth)
|
||||
)
|
||||
|
||||
def _refresh_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
full: bool = False,
|
||||
src_version: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Internal: submit a materialized-view refresh, return the job id.
|
||||
The public surface is ``MaterializedView.refresh()`` (which returns a
|
||||
`JobHandle`); this stays private so refresh is only reached through the
|
||||
handle.
|
||||
|
||||
``full=True`` forces a full rebuild (recompute and replace every row)
|
||||
instead of the default incremental refresh.
|
||||
"""
|
||||
return LOOP.run(
|
||||
self._conn._refresh_materialized_view(
|
||||
name,
|
||||
full=full,
|
||||
src_version=src_version,
|
||||
num_workers=num_workers,
|
||||
max_workers=max_workers,
|
||||
)
|
||||
)
|
||||
|
||||
def explain_refresh_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
full: bool = False,
|
||||
src_version: Optional[int] = None,
|
||||
):
|
||||
"""Plan a refresh without running it (EXPLAIN REFRESH). Returns a
|
||||
plan with .has_work / .source_version / .last_refreshed_version /
|
||||
.full_refresh / .rebuild / .units_total. `full=True` plans a full
|
||||
rebuild (incremental planning needs stable row IDs on the source)."""
|
||||
return LOOP.run(
|
||||
self._conn.explain_refresh_materialized_view(
|
||||
name, full=full, src_version=src_version
|
||||
)
|
||||
)
|
||||
|
||||
def alter_materialized_view(self, name: str, *, auto_refresh: bool):
|
||||
"""Update a materialized view's options (ALTER MATERIALIZED VIEW)."""
|
||||
LOOP.run(self._conn.alter_materialized_view(name, auto_refresh=auto_refresh))
|
||||
|
||||
def drop_materialized_view(self, name: str):
|
||||
"""Drop a materialized view definition (DROP MATERIALIZED VIEW)."""
|
||||
LOOP.run(self._conn.drop_materialized_view(name))
|
||||
|
||||
def list_materialized_views(self):
|
||||
"""List registered materialized view definitions."""
|
||||
return LOOP.run(self._conn.list_materialized_views())
|
||||
|
||||
def list_jobs(self):
|
||||
"""List inflight server-side jobs across the database's tables."""
|
||||
return LOOP.run(self._conn.list_jobs())
|
||||
|
||||
def get_job(self, job_id: str, table: "str | None" = None):
|
||||
"""Look up one server-side job by id (the wait()/status poll path).
|
||||
|
||||
Passing ``table`` (the job's table) lets the server answer with an O(1)
|
||||
single-node read instead of scanning the database's active jobs.
|
||||
Returns the job's status, or None if it's unknown or no longer active.
|
||||
"""
|
||||
return LOOP.run(self._conn.get_job(job_id, table))
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
|
||||
"""Cancel an inflight server-side job by id (CANCEL JOB).
|
||||
|
||||
Returns True if a matching inflight job was found and flagged for
|
||||
cancellation, False if none was inflight (already finished or
|
||||
unknown id) -- cancellation is best-effort.
|
||||
"""
|
||||
return LOOP.run(self._conn.cancel_job(job_id))
|
||||
|
||||
def job_history(self, job_id: "str | None" = None):
|
||||
"""Durable history of completed server-side jobs (SHOW JOB HISTORY).
|
||||
|
||||
Pass ``job_id`` to narrow to a single job. Unlike :meth:`list_jobs`
|
||||
(live, inflight) these are the terminal records.
|
||||
"""
|
||||
return LOOP.run(self._conn.job_history(job_id))
|
||||
|
||||
def errors(self, job_id: "str | None" = None, table: "str | None" = None):
|
||||
"""Per-row UDF errors recorded by ``error_policy=skip`` (SHOW ERRORS),
|
||||
optionally filtered by ``job_id`` and/or ``table``.
|
||||
"""
|
||||
return LOOP.run(self._conn.errors(job_id, table))
|
||||
|
||||
|
||||
class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
@@ -2041,200 +1787,6 @@ class AsyncConnection(object):
|
||||
)
|
||||
return AsyncTable(table)
|
||||
|
||||
# -- Derived compute: functions, materialized views, jobs -------------
|
||||
# Server-backed features (LanceDB Enterprise / Cloud); local
|
||||
# connections raise NotImplementedError for now.
|
||||
|
||||
async def create_function(
|
||||
self,
|
||||
name,
|
||||
language: str = "python",
|
||||
return_type: Optional[str] = None,
|
||||
body: Optional[str] = None,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Register a UDF (CREATE FUNCTION). Accepts a ``@udf``/``@table_udf``
|
||||
object (preferred) or the explicit (name, language, return_type, body,
|
||||
options)."""
|
||||
from .udf import Udf
|
||||
|
||||
if isinstance(name, Udf):
|
||||
req = name.create_request()
|
||||
name, language, return_type, body, options = (
|
||||
req["name"],
|
||||
req["language"],
|
||||
req["return_type"],
|
||||
req["body"],
|
||||
req["options"],
|
||||
)
|
||||
if replace:
|
||||
try:
|
||||
await self.drop_function(name)
|
||||
except Exception:
|
||||
pass
|
||||
await self._inner.create_function(name, language, return_type, body, options)
|
||||
|
||||
async def list_functions(self):
|
||||
"""List registered functions (SHOW FUNCTIONS)."""
|
||||
return await self._inner.list_functions()
|
||||
|
||||
async def drop_function(self, name: str):
|
||||
"""Drop a registered function (DROP FUNCTION)."""
|
||||
await self._inner.drop_function(name)
|
||||
|
||||
async def create_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
source=None,
|
||||
select=None,
|
||||
*,
|
||||
query: Optional[str] = None,
|
||||
where: Optional[str] = None,
|
||||
auto_refresh: bool = False,
|
||||
with_no_data: bool = False,
|
||||
replace: bool = False,
|
||||
partition_by: Optional[str] = None,
|
||||
) -> "AsyncMaterializedView":
|
||||
"""Create a materialized view; returns an `AsyncMaterializedView`
|
||||
handle (``.wait()`` blocks until populated). Pass either ``query=`` (a
|
||||
full SELECT) or ``source`` + ``select`` items; `partition_by`
|
||||
partitions the view's table function on a source column (index-cluster
|
||||
if the column is IVF-indexed, else distinct-value). See the sync
|
||||
method for the select grammar."""
|
||||
from .udf import build_view_query, AsyncMaterializedView
|
||||
|
||||
if query is None:
|
||||
if source is None or select is None:
|
||||
raise ValueError(
|
||||
"create_materialized_view needs either query= or both "
|
||||
"source and select"
|
||||
)
|
||||
query = build_view_query(source, select)
|
||||
if where:
|
||||
query += f" WHERE {where}"
|
||||
if replace:
|
||||
try:
|
||||
await self.drop_materialized_view(name)
|
||||
except Exception as e:
|
||||
msg = str(e).lower()
|
||||
if "not found" not in msg and "does not exist" not in msg:
|
||||
raise
|
||||
job_id = await self._inner.create_materialized_view(
|
||||
name,
|
||||
query,
|
||||
auto_refresh=auto_refresh,
|
||||
with_no_data=with_no_data,
|
||||
partition_by=partition_by,
|
||||
)
|
||||
return AsyncMaterializedView(self, name, job_id=job_id)
|
||||
|
||||
def job(self, job_id: str):
|
||||
"""An `AsyncJobHandle` for reconnecting to an inflight job by id (a
|
||||
stored id, or one from the SQL / REST surface). Submit methods already
|
||||
return a handle, so this is only needed to re-attach to an existing
|
||||
job."""
|
||||
from .udf import AsyncJobHandle
|
||||
|
||||
return AsyncJobHandle(self, job_id)
|
||||
|
||||
async def lineage(
|
||||
self,
|
||||
table: str,
|
||||
column: Optional[str] = None,
|
||||
*,
|
||||
direction: Optional[str] = None,
|
||||
depth: Optional[int] = None,
|
||||
):
|
||||
"""Derived-compute lineage of a table/view (or column). See the sync
|
||||
`Connection.lineage`. Returns a `Lineage`."""
|
||||
from .lineage import Lineage
|
||||
|
||||
raw = await self._inner.table_lineage(table, column, direction, depth)
|
||||
return Lineage.from_json(raw)
|
||||
|
||||
async def _refresh_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
full: bool = False,
|
||||
src_version: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Internal: submit a refresh, return the job id. The public surface is
|
||||
``AsyncMaterializedView.refresh()`` (returns an `AsyncJobHandle`).
|
||||
|
||||
``full=True`` forces a full rebuild (recompute and replace every row)
|
||||
instead of the default incremental refresh.
|
||||
"""
|
||||
return await self._inner.refresh_materialized_view(
|
||||
name,
|
||||
full=full,
|
||||
src_version=src_version,
|
||||
num_workers=num_workers,
|
||||
max_workers=max_workers,
|
||||
)
|
||||
|
||||
async def explain_refresh_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
full: bool = False,
|
||||
src_version: Optional[int] = None,
|
||||
):
|
||||
"""Plan a refresh without running it (EXPLAIN REFRESH)."""
|
||||
return await self._inner.explain_refresh_materialized_view(
|
||||
name, full=full, src_version=src_version
|
||||
)
|
||||
|
||||
async def alter_materialized_view(self, name: str, *, auto_refresh: bool):
|
||||
"""Update a materialized view's options."""
|
||||
await self._inner.alter_materialized_view(name, auto_refresh)
|
||||
|
||||
async def drop_materialized_view(self, name: str):
|
||||
"""Drop a materialized view definition."""
|
||||
await self._inner.drop_materialized_view(name)
|
||||
|
||||
async def list_materialized_views(self):
|
||||
"""List registered materialized view definitions."""
|
||||
return await self._inner.list_materialized_views()
|
||||
|
||||
async def list_jobs(self):
|
||||
"""List inflight server-side jobs across the database's tables."""
|
||||
return await self._inner.list_jobs()
|
||||
|
||||
async def get_job(self, job_id: str, table: "str | None" = None):
|
||||
"""Look up one server-side job by id (the wait()/status poll path).
|
||||
``table`` (the job's table) enables an O(1) server-side lookup.
|
||||
Returns the job's status, or None if unknown / no longer active."""
|
||||
return await self._inner.get_job(job_id, table)
|
||||
|
||||
async def cancel_job(self, job_id: str) -> bool:
|
||||
"""Cancel an inflight server-side job by id (CANCEL JOB).
|
||||
|
||||
Returns True if a matching inflight job was found and flagged for
|
||||
cancellation, False otherwise (best-effort).
|
||||
"""
|
||||
return await self._inner.cancel_job(job_id)
|
||||
|
||||
async def job_history(self, job_id: "str | None" = None):
|
||||
"""Durable history of completed server-side jobs (SHOW JOB HISTORY).
|
||||
|
||||
Reads each table's durable job-history store. Pass ``job_id`` to narrow
|
||||
to a single job. Unlike :meth:`list_jobs` (live, inflight) these are the
|
||||
terminal records, with created/updated/completed timestamps.
|
||||
"""
|
||||
return await self._inner.job_history(job_id)
|
||||
|
||||
async def errors(self, job_id: "str | None" = None, table: "str | None" = None):
|
||||
"""Per-row UDF errors recorded by ``error_policy=skip`` (SHOW ERRORS).
|
||||
|
||||
Optionally filtered by ``job_id`` and/or ``table``.
|
||||
"""
|
||||
return await self._inner.errors(job_id, table)
|
||||
|
||||
async def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
|
||||
@@ -81,7 +81,6 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
||||
warnings.warn(
|
||||
"use_token_pooling is deprecated, use pooling_strategy=None instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.pooling_strategy = None
|
||||
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
"""Client-side model of derived-compute lineage.
|
||||
|
||||
`Connection.lineage()` / `Table.lineage()` / `MaterializedView.lineage()` return
|
||||
a `Lineage`: the graph of what a column or materialized view derives from
|
||||
(upstream), what derives from it (downstream), and -- for each derived column --
|
||||
the function that produced it, the version it was produced with, and whether
|
||||
that is stale relative to the function the registry now holds.
|
||||
|
||||
The server returns this as JSON (the wire contract); these classes deserialize
|
||||
it. Nothing here talks to the server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionRef:
|
||||
"""The function that produced a derived column, with version + location."""
|
||||
|
||||
name: str
|
||||
#: Version that produced the data (stamped at compute time), if known.
|
||||
as_computed_version: Optional[str] = None
|
||||
#: Version the registry currently holds for this function name.
|
||||
current_version: Optional[str] = None
|
||||
#: True when the column was produced by an older function than the registry
|
||||
#: now holds -- i.e. silently stale; re-refresh to catch up.
|
||||
stale_vs_current: bool = False
|
||||
language: Optional[str] = None
|
||||
docker_image: Optional[str] = None
|
||||
env_digest: Optional[str] = None
|
||||
code_uri: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def _from(cls, d: dict) -> "FunctionRef":
|
||||
return cls(
|
||||
name=d["name"],
|
||||
as_computed_version=d.get("as_computed_version"),
|
||||
current_version=d.get("current_version"),
|
||||
stale_vs_current=d.get("stale_vs_current", False),
|
||||
language=d.get("language"),
|
||||
docker_image=d.get("docker_image"),
|
||||
env_digest=d.get("env_digest"),
|
||||
code_uri=d.get("code_uri"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Node:
|
||||
"""A lineage node: a table, view, column, or function."""
|
||||
|
||||
kind: str # "table" | "view" | "column" | "function"
|
||||
id: str # "table", "table.column", or "fn:name@version"
|
||||
table: Optional[str] = None
|
||||
function: Optional[FunctionRef] = None
|
||||
|
||||
@classmethod
|
||||
def _from(cls, d: dict) -> "Node":
|
||||
fn = d.get("function")
|
||||
return cls(
|
||||
kind=d["kind"],
|
||||
id=d["id"],
|
||||
table=d.get("table"),
|
||||
function=FunctionRef._from(fn) if fn else None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Edge:
|
||||
"""`downstream` depends on `upstream`, produced by `via` (a function name,
|
||||
or None for a passthrough)."""
|
||||
|
||||
downstream: str
|
||||
upstream: str
|
||||
via: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def _from(cls, d: dict) -> "Edge":
|
||||
return cls(downstream=d["downstream"], upstream=d["upstream"], via=d.get("via"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Lineage:
|
||||
"""A derived-compute lineage graph (nodes + labeled edges)."""
|
||||
|
||||
target: str
|
||||
nodes: List[Node] = field(default_factory=list)
|
||||
edges: List[Edge] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: Union[str, bytes, dict]) -> "Lineage":
|
||||
d = json.loads(raw) if isinstance(raw, (str, bytes)) else raw
|
||||
return cls(
|
||||
target=d.get("target", ""),
|
||||
nodes=[Node._from(n) for n in d.get("nodes", [])],
|
||||
edges=[Edge._from(e) for e in d.get("edges", [])],
|
||||
)
|
||||
|
||||
def functions(self) -> List[FunctionRef]:
|
||||
"""The function nodes in the graph."""
|
||||
return [n.function for n in self.nodes if n.function is not None]
|
||||
|
||||
def stale(self) -> List[FunctionRef]:
|
||||
"""Functions whose as-computed version is behind the current registry
|
||||
version -- the columns they produced are silently out of date."""
|
||||
return [f for f in self.functions() if f.stale_vs_current]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def prune(d: dict) -> dict:
|
||||
return {k: v for k, v in d.items() if v is not None}
|
||||
|
||||
return {
|
||||
"target": self.target,
|
||||
"nodes": [
|
||||
prune(
|
||||
{
|
||||
"kind": n.kind,
|
||||
"id": n.id,
|
||||
"table": n.table,
|
||||
"function": prune(vars(n.function)) if n.function else None,
|
||||
}
|
||||
)
|
||||
for n in self.nodes
|
||||
],
|
||||
"edges": [prune(vars(e)) for e in self.edges],
|
||||
}
|
||||
|
||||
def to_graphviz(self) -> str:
|
||||
"""Graphviz DOT for the lineage DAG: columns/tables as nodes, function
|
||||
names on edges, drift edges dashed + red."""
|
||||
stale_names = {f.name for f in self.stale()}
|
||||
out = [
|
||||
"digraph lineage {",
|
||||
" rankdir=LR;",
|
||||
' node [fontname="monospace"];',
|
||||
]
|
||||
for n in self.nodes:
|
||||
if n.kind == "function":
|
||||
continue
|
||||
shape = "ellipse" if n.kind in ("table", "view") else "box"
|
||||
out.append(f' "{n.id}" [shape={shape}];')
|
||||
for e in self.edges:
|
||||
attrs = ""
|
||||
if e.via:
|
||||
if e.via in stale_names:
|
||||
attrs = f' [label="{e.via}" color=red style=dashed]'
|
||||
else:
|
||||
attrs = f' [label="{e.via}"]'
|
||||
out.append(f' "{e.upstream}" -> "{e.downstream}"{attrs};')
|
||||
out.append("}")
|
||||
return "\n".join(out)
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
warn = ""
|
||||
drift = self.stale()
|
||||
if drift:
|
||||
names = ", ".join(sorted({f.name for f in drift}))
|
||||
warn = (
|
||||
f'<p style="color:#b00000"><b>stale vs current:</b> {names} '
|
||||
"(re-refresh to catch up)</p>"
|
||||
)
|
||||
rows = "".join(
|
||||
f"<tr><td><code>{e.downstream}</code></td>"
|
||||
f"<td>← {e.via or ''}</td>"
|
||||
f"<td><code>{e.upstream}</code></td></tr>"
|
||||
for e in self.edges
|
||||
)
|
||||
return (
|
||||
f"<b>lineage: <code>{self.target}</code></b>{warn}"
|
||||
"<table><tr><th>derived</th><th>via</th><th>from</th></tr>"
|
||||
f"{rows}</table>"
|
||||
)
|
||||
@@ -71,9 +71,6 @@ from lancedb.embeddings import EmbeddingFunctionConfig
|
||||
from ._lancedb import Session
|
||||
|
||||
|
||||
_MAX_QUERY_K = 2**31 - 1
|
||||
|
||||
|
||||
def _query_to_namespace_request(
|
||||
table_id: List[str],
|
||||
query: "Query",
|
||||
@@ -151,8 +148,7 @@ def _query_to_namespace_request(
|
||||
if query.limit is not None:
|
||||
k = query.limit
|
||||
elif query.vector is None and query.full_text_query is None:
|
||||
# limit k to max i32 value to avoid client overflows
|
||||
k = _MAX_QUERY_K
|
||||
k = sys.maxsize
|
||||
else:
|
||||
k = 10
|
||||
|
||||
@@ -373,19 +369,6 @@ def _convert_pyarrow_schema_to_json(schema: pa.Schema) -> JsonArrowSchema:
|
||||
return JsonArrowSchema(fields=fields, metadata=meta)
|
||||
|
||||
|
||||
def _builds_namespace_natively(
|
||||
namespace_client_impl: Optional[str],
|
||||
namespace_client_properties: Optional[Dict[str, str]],
|
||||
) -> bool:
|
||||
"""Whether ``connect_namespace_client`` builds the namespace client natively
|
||||
in Rust (installing the read-freshness context provider) rather than wrapping
|
||||
the pre-built Python client.
|
||||
|
||||
Must mirror Rust ``build_namespace_natively`` in ``python/src/connection.rs``.
|
||||
"""
|
||||
return namespace_client_impl == "rest" and bool(namespace_client_properties)
|
||||
|
||||
|
||||
class LanceNamespaceDBConnection(DBConnection):
|
||||
"""
|
||||
A LanceDB connection that uses a namespace for table management.
|
||||
@@ -445,13 +428,6 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
)
|
||||
self._namespace_client_impl = namespace_client_impl
|
||||
self._namespace_client_properties = namespace_client_properties
|
||||
# When the namespace client is built natively (see Rust
|
||||
# ``build_namespace_natively``), the underlying Rust table performs
|
||||
# QueryTable pushdown through the read-freshness context provider, which
|
||||
# the pure-Python ``query_table`` path bypasses.
|
||||
self._route_pushdown_to_rust = _builds_namespace_natively(
|
||||
namespace_client_impl, namespace_client_properties
|
||||
)
|
||||
self._inner = AsyncConnection(
|
||||
_connect_namespace_client(
|
||||
namespace_client,
|
||||
@@ -563,7 +539,6 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
_async=async_table,
|
||||
)
|
||||
|
||||
@@ -601,7 +576,6 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
_async=async_table,
|
||||
)
|
||||
if branch is not None:
|
||||
@@ -897,8 +871,6 @@ class AsyncLanceNamespaceDBConnection:
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
namespace_client_pushdown_operations: Optional[List[str]] = None,
|
||||
namespace_client_impl: Optional[str] = None,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize an async namespace-based LanceDB connection.
|
||||
@@ -924,12 +896,6 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace.create_table() instead of using declare_table + local write.
|
||||
|
||||
Default is None (no pushdown, all operations run locally).
|
||||
namespace_client_impl : Optional[str]
|
||||
The namespace implementation name used to create this connection.
|
||||
Required (with ``namespace_client_properties``) for the Rust client to
|
||||
be built natively and install the read-freshness provider.
|
||||
namespace_client_properties : Optional[Dict[str, str]]
|
||||
The namespace properties used to create this connection.
|
||||
"""
|
||||
self._namespace_client = namespace_client
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
@@ -938,14 +904,6 @@ class AsyncLanceNamespaceDBConnection:
|
||||
self._namespace_client_pushdown_operations = set(
|
||||
namespace_client_pushdown_operations or []
|
||||
)
|
||||
self._namespace_client_impl = namespace_client_impl
|
||||
self._namespace_client_properties = namespace_client_properties
|
||||
# See LanceNamespaceDBConnection: when built natively the Rust table runs
|
||||
# QueryTable pushdown through the read-freshness provider, so defer to it
|
||||
# rather than the urllib3 client (which omits x-lancedb-min-timestamp).
|
||||
self._route_pushdown_to_rust = _builds_namespace_natively(
|
||||
namespace_client_impl, namespace_client_properties
|
||||
)
|
||||
self._inner = AsyncConnection(
|
||||
_connect_namespace_client(
|
||||
namespace_client,
|
||||
@@ -959,8 +917,8 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace_client_pushdown_operations=(
|
||||
list(self._namespace_client_pushdown_operations)
|
||||
),
|
||||
namespace_client_impl=namespace_client_impl,
|
||||
namespace_client_properties=namespace_client_properties,
|
||||
namespace_client_impl=None,
|
||||
namespace_client_properties=None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1030,7 +988,6 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
)
|
||||
|
||||
async def open_table(
|
||||
@@ -1068,7 +1025,6 @@ class AsyncLanceNamespaceDBConnection:
|
||||
namespace_path=namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._namespace_client_pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
)
|
||||
|
||||
async def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
|
||||
@@ -1427,6 +1383,4 @@ def connect_namespace_async(
|
||||
storage_options=storage_options,
|
||||
session=session,
|
||||
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
|
||||
namespace_client_impl=namespace_client_impl,
|
||||
namespace_client_properties=namespace_client_properties,
|
||||
)
|
||||
|
||||
@@ -48,14 +48,6 @@ class PermutationBuilder:
|
||||
By default, the permutation builder will create a single split that contains all
|
||||
rows in the same order as the base table.
|
||||
"""
|
||||
if not hasattr(table, "_inner"):
|
||||
raise TypeError(
|
||||
f"PermutationBuilder requires a local LanceTable, "
|
||||
f"got {type(table).__name__}. "
|
||||
"The permutation API is not supported on remote tables. "
|
||||
"Remote tables connect to LanceDB Cloud or Enterprise and do not have "
|
||||
"direct access to the underlying Lance dataset needed for permutations."
|
||||
)
|
||||
self._async = async_permutation_builder(table)
|
||||
|
||||
def split_random(
|
||||
|
||||
@@ -275,18 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
tz = get_extras(field, "tz")
|
||||
return pa.timestamp("us", tz=tz)
|
||||
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
||||
# A bare, unparameterised ``typing.List`` / ``typing.Tuple`` matches this
|
||||
# branch (its ``__origin__`` is ``list`` / ``tuple``) but has no
|
||||
# ``__args__``, so we cannot infer the element type. Raise a clear
|
||||
# ``TypeError`` instead of crashing with an opaque ``AttributeError``.
|
||||
args = getattr(py_type, "__args__", None)
|
||||
if not args:
|
||||
raise TypeError(
|
||||
"Converting Pydantic type to Arrow Type: unsupported type "
|
||||
f"{py_type}. Specify the element type, e.g. List[int] instead "
|
||||
"of a bare List."
|
||||
)
|
||||
child = args[0]
|
||||
child = py_type.__args__[0]
|
||||
return _pydantic_list_child_to_arrow(child, field)
|
||||
raise TypeError(
|
||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import List, Optional
|
||||
from lancedb import __version__
|
||||
|
||||
from .header import HeaderProvider
|
||||
from .oauth import OAuthConfig, OAuthFlowType
|
||||
|
||||
__all__ = [
|
||||
"TimeoutConfig",
|
||||
@@ -17,8 +16,6 @@ __all__ = [
|
||||
"TlsConfig",
|
||||
"ClientConfig",
|
||||
"HeaderProvider",
|
||||
"OAuthConfig",
|
||||
"OAuthFlowType",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -124,7 +124,6 @@ class RemoteDBConnection(DBConnection):
|
||||
"request_thread_pool is no longer used and will be removed in "
|
||||
"a future release.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if connection_timeout is not None:
|
||||
@@ -133,7 +132,6 @@ class RemoteDBConnection(DBConnection):
|
||||
"release. Please use client_config.timeout_config.connect_timeout "
|
||||
"instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
client_config.timeout_config.connect_timeout = timedelta(
|
||||
seconds=connection_timeout
|
||||
@@ -144,7 +142,6 @@ class RemoteDBConnection(DBConnection):
|
||||
"read_timeout is deprecated and will be removed in a future release. "
|
||||
"Please use client_config.timeout_config.read_timeout instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
|
||||
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class OAuthFlowType(str, Enum):
|
||||
"""OAuth authentication flow types."""
|
||||
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
"""Client Credentials grant (service-to-service / M2M)."""
|
||||
|
||||
AZURE_MANAGED_IDENTITY = "azure_managed_identity"
|
||||
"""Azure Managed Identity via IMDS."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthConfig:
|
||||
"""OAuth configuration for LanceDB authentication.
|
||||
|
||||
All token acquisition and refresh is handled in the Rust layer.
|
||||
This config is passed through to Rust via PyO3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
issuer_url : str
|
||||
OIDC issuer URL or OAuth authority URL.
|
||||
For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
|
||||
client_id : str
|
||||
Application / Client ID.
|
||||
scopes : List[str]
|
||||
OAuth scopes to request.
|
||||
For Azure managed identity, exactly one scope or resource is required.
|
||||
For example: ``["api://{app_id}/.default"]``
|
||||
flow : OAuthFlowType
|
||||
Authentication flow to use. Default: CLIENT_CREDENTIALS.
|
||||
client_secret : Optional[str]
|
||||
Client secret (required for CLIENT_CREDENTIALS).
|
||||
managed_identity_client_id : Optional[str]
|
||||
Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
|
||||
refresh_buffer_secs : Optional[int]
|
||||
Seconds before expiry to trigger proactive refresh (default: 300).
|
||||
Keep this well below the token TTL; if it is greater than or equal to
|
||||
the TTL, each request refreshes the token.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Client Credentials (service-to-service):
|
||||
|
||||
>>> config = OAuthConfig(
|
||||
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
... client_id="app-id",
|
||||
... client_secret="secret",
|
||||
... scopes=["api://lancedb-api/.default"],
|
||||
... )
|
||||
|
||||
Azure Managed Identity:
|
||||
|
||||
>>> config = OAuthConfig(
|
||||
... issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
|
||||
... client_id="app-id",
|
||||
... scopes=["api://lancedb-api/.default"],
|
||||
... flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
|
||||
... )
|
||||
"""
|
||||
|
||||
issuer_url: str
|
||||
client_id: str
|
||||
scopes: List[str]
|
||||
flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
|
||||
client_secret: Optional[str] = field(default=None, repr=False)
|
||||
managed_identity_client_id: Optional[str] = None
|
||||
refresh_buffer_secs: Optional[int] = None
|
||||
@@ -13,14 +13,10 @@ from typing import (
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
Literal,
|
||||
overload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..udf import JobHandle
|
||||
import warnings
|
||||
|
||||
from lancedb import __version__
|
||||
@@ -849,8 +845,7 @@ class RemoteTable(Table):
|
||||
"""
|
||||
warnings.warn(
|
||||
"cleanup_old_versions() is a no-op on LanceDB Cloud. "
|
||||
"Tables are automatically cleaned up and optimized.",
|
||||
stacklevel=2,
|
||||
"Tables are automatically cleaned up and optimized."
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -862,8 +857,7 @@ class RemoteTable(Table):
|
||||
"""
|
||||
warnings.warn(
|
||||
"compact_files() is a no-op on LanceDB Cloud. "
|
||||
"Tables are automatically compacted and optimized.",
|
||||
stacklevel=2,
|
||||
"Tables are automatically compacted and optimized."
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -880,150 +874,15 @@ class RemoteTable(Table):
|
||||
"""
|
||||
warnings.warn(
|
||||
"optimize() is a no-op on LanceDB Cloud. "
|
||||
"Indices are optimized automatically.",
|
||||
stacklevel=2,
|
||||
"Indices are optimized automatically."
|
||||
)
|
||||
pass
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
return LOOP.run(self._table.count_rows(filter))
|
||||
|
||||
def add_columns(
|
||||
self,
|
||||
transforms: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
computed: Optional[Dict[str, tuple]] = None,
|
||||
) -> Optional[AddColumnsResult]:
|
||||
result = None
|
||||
if transforms is not None:
|
||||
result = LOOP.run(self._table.add_columns(transforms))
|
||||
if computed:
|
||||
LOOP.run(self._table.add_columns(computed=computed))
|
||||
return result
|
||||
|
||||
def refresh_column(
|
||||
self,
|
||||
columns,
|
||||
*,
|
||||
where: Optional[str] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
priority: Optional[str] = None,
|
||||
) -> "JobHandle":
|
||||
"""Trigger recompute of computed columns (REFRESH COLUMN).
|
||||
|
||||
The expression is resolved server-side from each column's stored
|
||||
binding; columns bound to the same struct-returning function
|
||||
refresh together. Returns a `JobHandle` to wait on, poll, or cancel
|
||||
(``tbl.refresh_column("c").wait()``). Server-backed feature
|
||||
(LanceDB Enterprise / Cloud).
|
||||
|
||||
num_workers / max_workers / batch_size / priority are per-refresh
|
||||
scheduling knobs (how to run THIS refresh) and override any default
|
||||
the function carries. `priority` is a Kueue tier
|
||||
(training | interactive | backfill).
|
||||
"""
|
||||
from ..udf import JobHandle
|
||||
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
job_id = LOOP.run(
|
||||
self._table.refresh_column(
|
||||
list(columns),
|
||||
where=where,
|
||||
num_workers=num_workers,
|
||||
max_workers=max_workers,
|
||||
batch_size=batch_size,
|
||||
priority=priority,
|
||||
)
|
||||
)
|
||||
return JobHandle(self._job_conn(), job_id)
|
||||
|
||||
def lineage(self, column=None, *, direction=None, depth=None):
|
||||
"""Derived-compute lineage of this table, or one of its columns:
|
||||
upstream sources, downstream dependents, and the function version +
|
||||
location that produced each derived column (with a drift flag). Returns
|
||||
a `Lineage`. See `Connection.lineage`."""
|
||||
return self._job_conn().lineage(
|
||||
self._name, column, direction=direction, depth=depth
|
||||
)
|
||||
|
||||
def _job_conn(self):
|
||||
"""A client connection for polling jobs this table spawns. Built lazily
|
||||
from the table's serialized connection state and cached (not pickled --
|
||||
a forked/unpickled table rebuilds it on next use)."""
|
||||
from lancedb import deserialize_conn
|
||||
|
||||
conn = getattr(self, "_job_conn_cache", None)
|
||||
if conn is None:
|
||||
conn = deserialize_conn(self._serialized_connection_state())
|
||||
self._job_conn_cache = conn
|
||||
return conn
|
||||
|
||||
def load_columns(
|
||||
self,
|
||||
source: Union[str, Iterable[str]],
|
||||
pk: str,
|
||||
columns: Union[Iterable[str], Dict[str, str]],
|
||||
*,
|
||||
source_format: str = "parquet",
|
||||
source_pk: Optional[str] = None,
|
||||
on_missing: str = "carry",
|
||||
source_storage_options: Optional[Dict[str, str]] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
commit_granularity: Optional[int] = None,
|
||||
priority: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Fill existing columns from an external source by primary-key join.
|
||||
|
||||
The distributed-job equivalent of Geneva's ``Table.load_columns()``:
|
||||
imports precomputed values (e.g. embeddings) from Parquet/Lance/IPC into
|
||||
this table, matching on a primary key. Returns the load job id.
|
||||
Server-backed feature (LanceDB Enterprise / Cloud).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source: str | list[str]
|
||||
One source URI or a list of URIs.
|
||||
pk: str
|
||||
Destination primary-key column. Also the source key unless
|
||||
``source_pk`` is given.
|
||||
columns: list[str] | dict[str, str]
|
||||
Value columns to load. A list loads same-named columns; a dict maps
|
||||
``{target: source}``.
|
||||
source_format: str
|
||||
``"parquet"`` (default), ``"lance"``, or ``"ipc"``.
|
||||
source_pk: str, optional
|
||||
Source primary-key column when it differs from ``pk``.
|
||||
on_missing: str
|
||||
Behavior for destination rows with no source match:
|
||||
``"carry"`` (default, keep existing), ``"null"``, or ``"error"``.
|
||||
"""
|
||||
if isinstance(source, str):
|
||||
source = [source]
|
||||
if isinstance(columns, dict):
|
||||
mappings = [(target, src) for target, src in columns.items()]
|
||||
else:
|
||||
mappings = [(c, None) for c in columns]
|
||||
return LOOP.run(
|
||||
self._table.load_columns(
|
||||
list(source),
|
||||
source_format,
|
||||
pk,
|
||||
mappings,
|
||||
source_key=source_pk,
|
||||
source_storage_options=source_storage_options,
|
||||
on_missing=on_missing,
|
||||
num_workers=num_workers,
|
||||
max_workers=max_workers,
|
||||
batch_size=batch_size,
|
||||
commit_granularity=commit_granularity,
|
||||
priority=priority,
|
||||
)
|
||||
)
|
||||
def add_columns(self, transforms: Dict[str, str]) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
|
||||
def alter_columns(
|
||||
self, *alterations: Iterable[Dict[str, str]]
|
||||
|
||||
@@ -86,10 +86,7 @@ def _from_list(data: list) -> Scannable:
|
||||
|
||||
@to_scannable.register(dict)
|
||||
def _from_dict(data: dict) -> Scannable:
|
||||
raise ValueError(
|
||||
"Cannot create or add rows from a single dictionary. "
|
||||
"Use a list of dictionaries instead."
|
||||
)
|
||||
raise ValueError("Cannot add a single dictionary to a table. Use a list.")
|
||||
|
||||
|
||||
@to_scannable.register(LanceModel)
|
||||
|
||||
@@ -243,10 +243,7 @@ def _into_pyarrow_reader(
|
||||
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
||||
|
||||
if isinstance(data, dict):
|
||||
raise ValueError(
|
||||
"Cannot create or add rows from a single dictionary. "
|
||||
"Use a list of dictionaries instead."
|
||||
)
|
||||
raise ValueError("Cannot add a single dictionary to a table. Use a list.")
|
||||
|
||||
if isinstance(data, list):
|
||||
# Handle empty list case
|
||||
@@ -702,24 +699,6 @@ def _normalize_progress(progress):
|
||||
return progress, False
|
||||
|
||||
|
||||
def _computed_groups(computed):
|
||||
"""Group computed columns by expression, preserving declaration order
|
||||
(struct-returning functions need their columns adjacent so schema order
|
||||
matches field order). Accepts the ergonomic forms -- `fn("col")` values
|
||||
and tuple keys for struct fan-out -- via `_normalize_computed`."""
|
||||
from .udf import _normalize_computed
|
||||
|
||||
groups = []
|
||||
for name, (sql_type, expression) in _normalize_computed(computed).items():
|
||||
for expr, cols in groups:
|
||||
if expr == expression:
|
||||
cols.append((name, sql_type))
|
||||
break
|
||||
else:
|
||||
groups.append((expression, [(name, sql_type)]))
|
||||
return groups
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
"""
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
@@ -825,59 +804,6 @@ class Table(ABC):
|
||||
"""The number of rows in this Table"""
|
||||
return self.count_rows(None)
|
||||
|
||||
def add_computed_column(
|
||||
self,
|
||||
columns,
|
||||
fn,
|
||||
args: Optional[List[str]] = None,
|
||||
types=None,
|
||||
) -> None:
|
||||
"""Declare computed column(s) bound to a UDF -- no compute happens
|
||||
here (the agent fills them lazily, or refresh_column() triggers a run).
|
||||
|
||||
.. deprecated::
|
||||
A computed column is an expression over a registered function, so
|
||||
bind it as one: ``add_columns(computed={"vec": embed("data")})``.
|
||||
``embed("data")`` applies the function to the `data` column and
|
||||
infers the type from the function's return signature -- the
|
||||
function never couples to a particular column. Prefer that form.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"add_computed_column is deprecated; use add_columns(computed="
|
||||
'{"vec": embed("data")}).',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from .udf import Udf, struct_field_types
|
||||
|
||||
multi = isinstance(columns, (tuple, list))
|
||||
if isinstance(fn, Udf):
|
||||
expr = fn.expression(*(args or []))
|
||||
if types is None:
|
||||
if multi:
|
||||
if not fn.returns.upper().startswith("STRUCT"):
|
||||
raise ValueError(
|
||||
"several columns need a STRUCT-returning function"
|
||||
)
|
||||
types = struct_field_types(fn.returns)
|
||||
else:
|
||||
types = fn.returns
|
||||
else:
|
||||
if types is None:
|
||||
raise ValueError("pass types= when fn is a name string")
|
||||
expr = f"{fn}({', '.join(args or [])})"
|
||||
if multi:
|
||||
if len(types) != len(columns):
|
||||
raise ValueError(
|
||||
f"{len(columns)} columns but {len(types)} output types"
|
||||
)
|
||||
computed = {c: (t, expr) for c, t in zip(columns, types)}
|
||||
else:
|
||||
computed = {columns: (types, expr)}
|
||||
self.add_columns(computed=computed)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
|
||||
@@ -2093,7 +2019,6 @@ class LanceTable(Table):
|
||||
namespace_client: Optional[Any] = None,
|
||||
managed_versioning: Optional[bool] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
_async: AsyncTable = None,
|
||||
):
|
||||
if namespace_path is None:
|
||||
@@ -2103,14 +2028,6 @@ class LanceTable(Table):
|
||||
self._location = location # Store location for use in _dataset_path
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
# When the connection built the namespace client natively (e.g. an
|
||||
# enterprise "rest" connection), the underlying Rust table already
|
||||
# executes QueryTable pushdown itself -- and, unlike this Python urllib3
|
||||
# path, it routes through the read-freshness context provider that emits
|
||||
# the ``x-lancedb-min-timestamp`` header. So we must defer pushdown to
|
||||
# Rust instead of calling the Python ``namespace_client.query_table``
|
||||
# directly, or reads silently bypass read-freshness (stale results).
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
if _async is not None:
|
||||
self._table = _async
|
||||
else:
|
||||
@@ -2321,7 +2238,6 @@ class LanceTable(Table):
|
||||
namespace_path=self._namespace_path,
|
||||
namespace_client=self._namespace_client,
|
||||
pushdown_operations=self._pushdown_operations,
|
||||
route_pushdown_to_rust=self._route_pushdown_to_rust,
|
||||
location=self._location,
|
||||
_async=async_table,
|
||||
)
|
||||
@@ -2472,11 +2388,8 @@ class LanceTable(Table):
|
||||
Returns
|
||||
-------
|
||||
pa.Table"""
|
||||
if (
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
if _should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
):
|
||||
return self._execute_query(Query()).read_all()
|
||||
|
||||
@@ -3428,7 +3341,6 @@ class LanceTable(Table):
|
||||
location: Optional[str] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -3491,24 +3403,21 @@ class LanceTable(Table):
|
||||
self._location = location
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
|
||||
if data_storage_version is not None:
|
||||
warnings.warn(
|
||||
"setting data_storage_version directly on create_table is deprecated. "
|
||||
"setting data_storage_version directly on create_table is deprecated. ",
|
||||
"Use database_options instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if storage_options is None:
|
||||
storage_options = {}
|
||||
storage_options["new_table_data_storage_version"] = data_storage_version
|
||||
if enable_v2_manifest_paths is not None:
|
||||
warnings.warn(
|
||||
"setting enable_v2_manifest_paths directly on create_table is "
|
||||
"setting enable_v2_manifest_paths directly on create_table is ",
|
||||
"deprecated. Use database_options instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if storage_options is None:
|
||||
storage_options = {}
|
||||
@@ -3605,7 +3514,6 @@ class LanceTable(Table):
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
and self.current_branch() is None
|
||||
):
|
||||
from lancedb.namespace import _execute_server_side_query
|
||||
@@ -3781,68 +3689,9 @@ class LanceTable(Table):
|
||||
return LOOP.run(self._table.index_stats(index_name))
|
||||
|
||||
def add_columns(
|
||||
self,
|
||||
transforms: Dict[str, str]
|
||||
| pa.field
|
||||
| List[pa.field]
|
||||
| pa.Schema
|
||||
| None = None,
|
||||
*,
|
||||
computed: Optional[Dict] = None,
|
||||
) -> Optional[AddColumnsResult]:
|
||||
result = None
|
||||
if transforms is not None:
|
||||
result = LOOP.run(self._table.add_columns(transforms))
|
||||
if computed:
|
||||
# computed binds an expression over a registered function to a
|
||||
# column: {col: fn("input_col")} -- fn("input_col") yields the
|
||||
# expression and carries the inferred type; a tuple key fans a
|
||||
# STRUCT return out to several columns. Declares the binding only;
|
||||
# the server fills the values (server-backed). The legacy
|
||||
# {col: (sql_type, expression)} tuple form is still accepted.
|
||||
result_unused = LOOP.run(self._table.add_columns(computed=computed))
|
||||
del result_unused
|
||||
return result
|
||||
|
||||
def refresh_column(
|
||||
self,
|
||||
columns,
|
||||
*,
|
||||
where: Optional[str] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
priority: Optional[str] = None,
|
||||
) -> "JobHandle":
|
||||
"""Trigger recompute of computed columns (REFRESH COLUMN).
|
||||
|
||||
The expression is resolved server-side from each column's stored
|
||||
binding; columns bound to the same struct-returning function
|
||||
refresh together. Returns a `JobHandle` to wait on, poll, or cancel
|
||||
(``tbl.refresh_column("col").wait()``) -- mirrors
|
||||
`MaterializedView.refresh()`. Server-backed feature (LanceDB
|
||||
Enterprise / Cloud).
|
||||
|
||||
num_workers / max_workers / batch_size / priority are per-refresh
|
||||
scheduling knobs (how to run THIS refresh) and override any default
|
||||
the function carries. `priority` is a Kueue tier
|
||||
(training | interactive | backfill).
|
||||
"""
|
||||
from .udf import JobHandle
|
||||
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
job_id = LOOP.run(
|
||||
self._table.refresh_column(
|
||||
list(columns),
|
||||
where=where,
|
||||
num_workers=num_workers,
|
||||
max_workers=max_workers,
|
||||
batch_size=batch_size,
|
||||
priority=priority,
|
||||
)
|
||||
)
|
||||
return JobHandle(self._conn, job_id, table=self.name)
|
||||
self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
|
||||
def alter_columns(
|
||||
self, *alterations: Iterable[Dict[str, str]]
|
||||
@@ -4406,7 +4255,6 @@ class AsyncTable:
|
||||
namespace_path: Optional[List[str]] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
):
|
||||
"""Create a new AsyncTable object.
|
||||
|
||||
@@ -4419,9 +4267,6 @@ class AsyncTable:
|
||||
self._namespace_path = namespace_path or []
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
# See LanceTable.__init__: defer QueryTable pushdown to Rust (which emits
|
||||
# the read-freshness header) for natively-built namespace clients.
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
|
||||
def _set_namespace_context(
|
||||
self,
|
||||
@@ -4429,12 +4274,10 @@ class AsyncTable:
|
||||
namespace_path: Optional[List[str]] = None,
|
||||
namespace_client: Optional[Any] = None,
|
||||
pushdown_operations: Optional[set] = None,
|
||||
route_pushdown_to_rust: bool = False,
|
||||
) -> "AsyncTable":
|
||||
self._namespace_path = namespace_path or []
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
self._route_pushdown_to_rust = route_pushdown_to_rust
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
@@ -4644,11 +4487,8 @@ class AsyncTable:
|
||||
-------
|
||||
pa.Table
|
||||
"""
|
||||
if (
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
if _should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
):
|
||||
return (await self._execute_query(Query())).read_all()
|
||||
|
||||
@@ -5332,11 +5172,8 @@ class AsyncTable:
|
||||
batch_size: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
if (
|
||||
_should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
)
|
||||
and not self._route_pushdown_to_rust
|
||||
if _should_push_down_query_table(
|
||||
self._namespace_client, self._pushdown_operations
|
||||
):
|
||||
from lancedb.namespace import _execute_server_side_query
|
||||
|
||||
@@ -5520,44 +5357,9 @@ class AsyncTable:
|
||||
|
||||
return await self._inner.update(updates_sql, where)
|
||||
|
||||
async def refresh_column(
|
||||
self,
|
||||
columns,
|
||||
*,
|
||||
where: Optional[str] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
priority: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Trigger recompute of computed columns (REFRESH COLUMN).
|
||||
Returns the refresh job id. Server-backed feature.
|
||||
|
||||
num_workers / max_workers / batch_size / priority are per-refresh
|
||||
scheduling knobs (how to run THIS refresh); they override any default
|
||||
the function carries. `priority` is a Kueue tier
|
||||
(training | interactive | backfill)."""
|
||||
if isinstance(columns, str):
|
||||
columns = [columns]
|
||||
return await self._inner.refresh_column(
|
||||
list(columns),
|
||||
where_clause=where,
|
||||
num_workers=num_workers,
|
||||
max_workers=max_workers,
|
||||
batch_size=batch_size,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
async def add_columns(
|
||||
self,
|
||||
transforms: dict[str, str]
|
||||
| pa.field
|
||||
| List[pa.field]
|
||||
| pa.Schema
|
||||
| None = None,
|
||||
*,
|
||||
computed: Optional[Dict] = None,
|
||||
) -> Optional[AddColumnsResult]:
|
||||
self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
) -> AddColumnsResult:
|
||||
"""
|
||||
Add new columns with defined values.
|
||||
|
||||
@@ -5576,7 +5378,6 @@ class AsyncTable:
|
||||
version: the new version number of the table after adding columns.
|
||||
|
||||
"""
|
||||
result = None
|
||||
if isinstance(transforms, pa.Field):
|
||||
transforms = [transforms]
|
||||
if isinstance(transforms, list) and all(
|
||||
@@ -5584,69 +5385,9 @@ class AsyncTable:
|
||||
):
|
||||
transforms = pa.schema(transforms)
|
||||
if isinstance(transforms, pa.Schema):
|
||||
result = await self._inner.add_columns_with_schema(transforms)
|
||||
elif transforms is not None:
|
||||
result = await self._inner.add_columns(list(transforms.items()))
|
||||
if computed:
|
||||
# computed binds an expression over a registered function to a
|
||||
# column: {col: fn("input_col")} -- fn("input_col") yields the
|
||||
# expression and carries the inferred type; a tuple key fans a
|
||||
# STRUCT return out to several columns. Declares the binding only;
|
||||
# the server fills the values (server-backed). The legacy
|
||||
# {col: (sql_type, expression)} tuple form is still accepted.
|
||||
for expression, cols in _computed_groups(computed):
|
||||
await self._inner.add_computed_columns(cols, expression)
|
||||
return result
|
||||
|
||||
async def add_computed_column(
|
||||
self,
|
||||
columns,
|
||||
fn,
|
||||
args: Optional[List[str]] = None,
|
||||
types=None,
|
||||
) -> None:
|
||||
"""Declare computed column(s) bound to a UDF (async).
|
||||
|
||||
.. deprecated::
|
||||
Use ``add_columns(computed={"col": fn("input_col")})`` -- a computed
|
||||
column is an expression over a registered function, so bind it that
|
||||
way instead of coupling the UDF to the column here.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"add_computed_column is deprecated; use add_columns(computed="
|
||||
'{"col": fn("input_col")}).',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from .udf import Udf, struct_field_types
|
||||
|
||||
multi = isinstance(columns, (tuple, list))
|
||||
if isinstance(fn, Udf):
|
||||
expr = fn.expression(*(args or []))
|
||||
if types is None:
|
||||
if multi:
|
||||
if not fn.returns.upper().startswith("STRUCT"):
|
||||
raise ValueError(
|
||||
"several columns need a STRUCT-returning function"
|
||||
)
|
||||
types = struct_field_types(fn.returns)
|
||||
else:
|
||||
types = fn.returns
|
||||
return await self._inner.add_columns_with_schema(transforms)
|
||||
else:
|
||||
if types is None:
|
||||
raise ValueError("pass types= when fn is a name string")
|
||||
expr = f"{fn}({', '.join(args or [])})"
|
||||
if multi:
|
||||
if len(types) != len(columns):
|
||||
raise ValueError(
|
||||
f"{len(columns)} columns but {len(types)} output types"
|
||||
)
|
||||
computed = {c: (t, expr) for c, t in zip(columns, types)}
|
||||
else:
|
||||
computed = {columns: (types, expr)}
|
||||
await self.add_columns(computed=computed)
|
||||
return await self._inner.add_columns(list(transforms.items()))
|
||||
|
||||
async def alter_columns(
|
||||
self, *alterations: Iterable[dict[str, Any]]
|
||||
@@ -5918,7 +5659,6 @@ class AsyncTable:
|
||||
"The 'retrain' parameter is deprecated and will be removed in a "
|
||||
"future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return await self._inner.optimize(
|
||||
|
||||
@@ -1,753 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
"""UDF authoring for LanceDB derived compute (server-backed).
|
||||
|
||||
`@udf` / `@table_udf` turn a plain Python function into a registrable
|
||||
server-side UDF: a cloudpickled (or source) body, a SQL signature inferred
|
||||
from type hints, and the runtime options (pip deps, GPUs, batching, ...).
|
||||
Register and use them through the existing connection/table API:
|
||||
|
||||
import lancedb
|
||||
from lancedb import udf, table_udf
|
||||
|
||||
db = lancedb.connect("db://my_db", api_key="...", host_override="...")
|
||||
|
||||
@udf(pip=["torch>=2.0"], num_gpus=1)
|
||||
def embed(text: str) -> list[float]:
|
||||
return model.encode(text).tolist()
|
||||
|
||||
db.create_function(embed) # CREATE FUNCTION (once)
|
||||
tbl = db.open_table("docs")
|
||||
tbl.add_columns(computed={"vec": embed("text")}) # bind embed(text) -> vec
|
||||
tbl.refresh_column("vec").wait() # materialize (returns a JobHandle)
|
||||
view = db.create_materialized_view("chunks", tbl, ["id", chunk_fn])
|
||||
|
||||
`embed("text")` applies the registered function to the `text` column and yields
|
||||
the expression `embed(text)`; the function itself stays decoupled from any
|
||||
column, so the same `embed` works on any column or table.
|
||||
|
||||
These operations are server-backed (LanceDB Enterprise / Cloud); the
|
||||
decorator itself works locally (define + call), only registration needs a
|
||||
remote connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import typing
|
||||
|
||||
# -- type hints -> SQL type strings -------------------------------------
|
||||
|
||||
_SCALARS = {
|
||||
int: "BIGINT",
|
||||
# Pragmatic default for ML workloads: python float maps to FLOAT
|
||||
# (Float32). Use an explicit `returns=` for DOUBLE.
|
||||
float: "FLOAT",
|
||||
str: "VARCHAR",
|
||||
bool: "BOOLEAN",
|
||||
bytes: "BLOB",
|
||||
}
|
||||
|
||||
|
||||
class TypeInferenceError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
def sql_type(hint) -> str:
|
||||
"""SQL type string for a python type hint."""
|
||||
if hint in _SCALARS:
|
||||
return _SCALARS[hint]
|
||||
origin = typing.get_origin(hint)
|
||||
if origin in (list, typing.List):
|
||||
(item,) = typing.get_args(hint) or (None,)
|
||||
if item in _SCALARS:
|
||||
return f"{_SCALARS[item]}[]"
|
||||
raise TypeInferenceError(
|
||||
f"unsupported list item type {item!r}; use an explicit returns="
|
||||
)
|
||||
fields = _struct_fields(hint)
|
||||
if fields is not None:
|
||||
inner = ", ".join(f"{name} {sql_type(h)}" for name, h in fields)
|
||||
return f"STRUCT({inner})"
|
||||
raise TypeInferenceError(
|
||||
f"cannot infer a SQL type for {hint!r}; pass an explicit type string"
|
||||
)
|
||||
|
||||
|
||||
def _struct_fields(hint):
|
||||
"""(name, hint) pairs for a TypedDict or dataclass, else None."""
|
||||
if dataclasses.is_dataclass(hint):
|
||||
return [(f.name, f.type) for f in dataclasses.fields(hint)]
|
||||
# TypedDict detection: a dict subclass with __annotations__.
|
||||
if (
|
||||
isinstance(hint, type)
|
||||
and issubclass(hint, dict)
|
||||
and typing.get_type_hints(hint)
|
||||
):
|
||||
return list(typing.get_type_hints(hint).items())
|
||||
return None
|
||||
|
||||
|
||||
def return_type(fn, override: "str | None", table: bool) -> str:
|
||||
"""SQL return type for a function: explicit override wins, else the
|
||||
return annotation. Table functions render as TABLE(...) and accept
|
||||
struct-shaped hints (TypedDict/dataclass, optionally list-wrapped)."""
|
||||
if override is not None:
|
||||
s = override.strip()
|
||||
if table and not s.upper().startswith("TABLE"):
|
||||
if s.upper().startswith("STRUCT"):
|
||||
return "TABLE" + s[len("STRUCT") :]
|
||||
raise TypeInferenceError(
|
||||
"a table function's returns= must be TABLE(...) or STRUCT(...)"
|
||||
)
|
||||
return s
|
||||
|
||||
hints = typing.get_type_hints(fn)
|
||||
ret = hints.get("return")
|
||||
if ret is None:
|
||||
raise TypeInferenceError(
|
||||
f"function {fn.__name__!r} needs a return annotation or returns="
|
||||
)
|
||||
if table:
|
||||
# Accept list[Row] / Row where Row is a TypedDict or dataclass.
|
||||
if typing.get_origin(ret) in (list, typing.List):
|
||||
(ret,) = typing.get_args(ret)
|
||||
fields = _struct_fields(ret)
|
||||
if fields is None:
|
||||
raise TypeInferenceError(
|
||||
"a table function must return rows shaped as a TypedDict or "
|
||||
"dataclass (optionally list-wrapped); or pass returns=..."
|
||||
)
|
||||
inner = ", ".join(f"{name} {sql_type(h)}" for name, h in fields)
|
||||
return f"TABLE({inner})"
|
||||
return sql_type(ret)
|
||||
|
||||
|
||||
def param_types(fn) -> "list[tuple[str, str]]":
|
||||
"""(name, sql type) per parameter, from annotations. Each UDF
|
||||
parameter binds to a source column of the same name by default."""
|
||||
hints = typing.get_type_hints(fn)
|
||||
out = []
|
||||
for name, p in inspect.signature(fn).parameters.items():
|
||||
if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
|
||||
raise TypeInferenceError("*args/**kwargs are not supported in UDFs")
|
||||
hint = hints.get(name)
|
||||
if hint is None:
|
||||
raise TypeInferenceError(
|
||||
f"parameter {name!r} of {fn.__name__!r} needs a type annotation"
|
||||
)
|
||||
out.append((name, sql_type(hint)))
|
||||
return out
|
||||
|
||||
|
||||
# -- column expressions -------------------------------------------------
|
||||
|
||||
|
||||
class ColumnExpr(str):
|
||||
"""A computed-column expression produced by applying a registered
|
||||
function to column names, e.g. ``embed("data") -> "embed(data)"``.
|
||||
|
||||
It IS the expression string everywhere a string is expected (views, SQL,
|
||||
logging), and additionally carries the function's declared return type so
|
||||
``add_columns(computed=...)`` can declare the column without a hand-written
|
||||
type. ``field_types`` holds the per-field SQL types of a STRUCT return, for
|
||||
fanning one expression out to several columns.
|
||||
"""
|
||||
|
||||
data_type: "str | None"
|
||||
field_types: "list[str] | None"
|
||||
|
||||
def __new__(cls, expr: str, data_type=None, field_types=None):
|
||||
obj = super().__new__(cls, expr)
|
||||
obj.data_type = data_type
|
||||
obj.field_types = field_types
|
||||
return obj
|
||||
|
||||
|
||||
def _normalize_computed(computed: dict) -> dict:
|
||||
"""Normalize the user-facing ``computed=`` mapping to the canonical
|
||||
``{name: (sql_type, expression)}`` form.
|
||||
|
||||
Accepts, per entry:
|
||||
- value is a `ColumnExpr` (from ``fn("col")``): the column's SQL type
|
||||
comes from the function's return type -- no hand-written type needed. A
|
||||
tuple key (``("chunk", "idx")``) fans a STRUCT return out to one
|
||||
(type, expression) entry per field, in declared order.
|
||||
- value is a legacy ``(sql_type, expression)`` tuple: passed through (the
|
||||
escape hatch, e.g. bare-name function strings).
|
||||
"""
|
||||
out: dict = {}
|
||||
for key, val in computed.items():
|
||||
if isinstance(val, ColumnExpr):
|
||||
expr = str(val)
|
||||
if isinstance(key, (tuple, list)):
|
||||
if not val.field_types:
|
||||
raise ValueError(
|
||||
f"columns {tuple(key)} need a STRUCT-returning function; "
|
||||
f"{expr} returns a single value"
|
||||
)
|
||||
if len(val.field_types) != len(key):
|
||||
raise ValueError(
|
||||
f"{len(key)} columns but {len(val.field_types)} struct fields "
|
||||
f"in {expr}"
|
||||
)
|
||||
for name, t in zip(key, val.field_types):
|
||||
out[name] = (t, expr)
|
||||
else:
|
||||
if val.data_type is None:
|
||||
raise ValueError(f"cannot infer a type for {expr}; pass types=")
|
||||
out[key] = (val.data_type, expr)
|
||||
else:
|
||||
out[key] = val
|
||||
return out
|
||||
|
||||
|
||||
# -- the @udf / @table_udf decorators -----------------------------------
|
||||
|
||||
|
||||
class Udf:
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
*,
|
||||
returns: "str | None" = None,
|
||||
table: bool = False,
|
||||
name: "str | None" = None,
|
||||
pip: "list[str] | None" = None,
|
||||
pip_index_url: "str | None" = None,
|
||||
pip_extra_index_urls: "list[str] | None" = None,
|
||||
find_links: "list[str] | None" = None,
|
||||
requirements: "str | list[str] | None" = None,
|
||||
conda: "list[str] | None" = None,
|
||||
conda_channels: "list[str] | None" = None,
|
||||
env: "dict[str, str] | list[str] | None" = None,
|
||||
num_cpus: "int | None" = None,
|
||||
num_gpus: "int | None" = None,
|
||||
batch_size: "int | None" = None,
|
||||
timeout: "float | None" = None,
|
||||
error_policy: "str | None" = None,
|
||||
max_skip_ratio: "float | None" = None,
|
||||
retries: "int | None" = None,
|
||||
docker_image: "str | None" = None,
|
||||
description: "str | None" = None,
|
||||
prefer_source: bool = False,
|
||||
):
|
||||
functools.update_wrapper(self, fn)
|
||||
self.fn = fn
|
||||
self.name = name or fn.__name__
|
||||
self.table = table
|
||||
self.params = param_types(fn)
|
||||
self.returns = return_type(fn, returns, table)
|
||||
self.prefer_source = prefer_source
|
||||
self.options: "dict[str, str]" = {}
|
||||
if conda and (pip or requirements):
|
||||
raise ValueError("pass conda or pip/requirements, not both")
|
||||
if conda_channels and not conda:
|
||||
raise ValueError("conda_channels requires conda")
|
||||
if pip:
|
||||
self.options["pip"] = ",".join(pip)
|
||||
if pip_extra_index_urls:
|
||||
self.options["pip_extra_index_urls"] = ",".join(pip_extra_index_urls)
|
||||
if find_links:
|
||||
self.options["find_links"] = ",".join(find_links)
|
||||
if requirements:
|
||||
self.options["requirements"] = _format_requirements(requirements)
|
||||
if conda:
|
||||
self.options["conda"] = ",".join(conda)
|
||||
if conda_channels:
|
||||
self.options["conda_channels"] = ",".join(conda_channels)
|
||||
if env:
|
||||
self.options["env"] = _format_env(env)
|
||||
for key, val in [
|
||||
("pip_index_url", pip_index_url),
|
||||
("num_cpus", num_cpus),
|
||||
("num_gpus", num_gpus),
|
||||
("batch_size", batch_size),
|
||||
("timeout", timeout),
|
||||
("error_policy", error_policy),
|
||||
("max_skip_ratio", max_skip_ratio),
|
||||
("retries", retries),
|
||||
("docker_image", docker_image),
|
||||
]:
|
||||
if val is not None:
|
||||
self.options[key] = str(val)
|
||||
# Keep the source in the description (when available) so the
|
||||
# catalog stays inspectable even for pickled bodies.
|
||||
if description is not None:
|
||||
self.options["description"] = description
|
||||
else:
|
||||
try:
|
||||
self.options["description"] = textwrap.dedent(inspect.getsource(fn))
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Call with real values to run locally; call with column-name
|
||||
strings to build an expression for backfills and views, e.g.
|
||||
``embed("data")`` -> the expression ``embed(data)`` (a `ColumnExpr`
|
||||
carrying the function's return type for `add_columns(computed=...)`)."""
|
||||
if args and all(isinstance(a, str) for a in args) and not kwargs:
|
||||
return self.expression(*args)
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
def expression(self, *columns: str) -> ColumnExpr:
|
||||
"""The expression applying this function to `columns` (default: the
|
||||
function's own parameter names). Returns a `ColumnExpr` -- a string
|
||||
that also carries the declared return type (and struct field types)."""
|
||||
cols = columns or [p for p, _ in self.params]
|
||||
expr = f"{self.name}({', '.join(cols)})"
|
||||
field_types = None
|
||||
if self.returns.upper().startswith("STRUCT"):
|
||||
field_types = struct_field_types(self.returns)
|
||||
return ColumnExpr(expr, data_type=self.returns, field_types=field_types)
|
||||
|
||||
def _body(self) -> "tuple[str, str]":
|
||||
"""(body literal, body_format). Source when requested and
|
||||
retrievable; cloudpickle otherwise (handles closures)."""
|
||||
if self.prefer_source:
|
||||
try:
|
||||
src = textwrap.dedent(inspect.getsource(self.fn))
|
||||
# Strip the decorator line(s) so the stored body is a
|
||||
# plain function definition.
|
||||
lines = src.splitlines(keepends=True)
|
||||
while lines and lines[0].lstrip().startswith("@"):
|
||||
lines.pop(0)
|
||||
return "".join(lines), "source"
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
import cloudpickle
|
||||
|
||||
raw = cloudpickle.dumps(self.fn)
|
||||
return base64.b64encode(raw).decode("ascii"), "cloudpickle"
|
||||
|
||||
def _body_and_options(self) -> "tuple[str, dict[str, str]]":
|
||||
"""The body literal plus the finalized options (body_format /
|
||||
python_version / cloudpickle-pip bookkeeping for a non-source
|
||||
body)."""
|
||||
body, body_format = self._body()
|
||||
options = dict(self.options)
|
||||
if body_format != "source":
|
||||
options["body_format"] = body_format
|
||||
# Pickled code objects only load under the same interpreter
|
||||
# minor version; record ours so the worker can fail with a
|
||||
# clear message instead of a bytecode error.
|
||||
options["python_version"] = self.pickle_environment()
|
||||
# The worker deserializes the body with cloudpickle; make sure
|
||||
# the job's pip environment provides it. Conda bakes inject
|
||||
# cloudpickle server-side, so do not create an invalid pip+conda
|
||||
# declaration here.
|
||||
if "conda" not in options:
|
||||
pip = [d for d in options.get("pip", "").split(",") if d]
|
||||
if not any(d.startswith("cloudpickle") for d in pip):
|
||||
pip.append("cloudpickle")
|
||||
options["pip"] = ",".join(pip)
|
||||
return body, options
|
||||
|
||||
def create_request(self) -> dict:
|
||||
"""Keyword arguments for `connection.create_function`."""
|
||||
body, options = self._body_and_options()
|
||||
return {
|
||||
"name": self.name,
|
||||
"language": "python",
|
||||
"return_type": self.returns,
|
||||
"body": body,
|
||||
"options": options,
|
||||
}
|
||||
|
||||
def create_statement(self) -> str:
|
||||
"""The equivalent `CREATE FUNCTION` SQL (for SQL-surface callers)."""
|
||||
params = ", ".join(f"{n} {t}" for n, t in self.params)
|
||||
body, options = self._body_and_options()
|
||||
with_clause = ""
|
||||
if options:
|
||||
rendered = ", ".join(
|
||||
f"{k} = '{_escape(v)}'" for k, v in sorted(options.items())
|
||||
)
|
||||
with_clause = f" WITH ({rendered})"
|
||||
return (
|
||||
f"CREATE FUNCTION {self.name}({params}) RETURNS {self.returns} "
|
||||
f"LANGUAGE python AS '{_escape_body(body)}'{with_clause}"
|
||||
)
|
||||
|
||||
def pickle_environment(self) -> str:
|
||||
"""Python version the body pickles under -- workers should match
|
||||
the minor version for cloudpickle compatibility."""
|
||||
return f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
|
||||
def _escape(s: str) -> str:
|
||||
return str(s).replace("'", "''")
|
||||
|
||||
|
||||
def _format_requirements(requirements: "str | list[str]") -> str:
|
||||
if isinstance(requirements, str):
|
||||
return requirements
|
||||
return "\n".join(str(req) for req in requirements)
|
||||
|
||||
|
||||
def _format_env(env: "dict[str, str] | list[str]") -> str:
|
||||
if isinstance(env, dict):
|
||||
return "; ".join(f"{key}={value}" for key, value in env.items())
|
||||
return "; ".join(str(entry) for entry in env)
|
||||
|
||||
|
||||
def _escape_body(body: str) -> str:
|
||||
# The server unescapes \n / \t in single-quoted bodies; encode real
|
||||
# newlines accordingly and escape quotes.
|
||||
return (
|
||||
body.replace("\\", "\\\\")
|
||||
.replace("'", "''")
|
||||
.replace("\n", "\\n")
|
||||
.replace("\t", "\\t")
|
||||
)
|
||||
|
||||
|
||||
def udf(fn=None, **kwargs):
|
||||
"""Decorate a function as a scalar (or struct-returning) UDF.
|
||||
|
||||
@udf
|
||||
def doubled(val: int) -> float: ...
|
||||
|
||||
@udf(pip=["torch>=2"], num_gpus=1)
|
||||
def embed(body: str) -> list[float]: ...
|
||||
"""
|
||||
if fn is not None:
|
||||
return Udf(fn, **kwargs)
|
||||
return lambda f: Udf(f, **kwargs)
|
||||
|
||||
|
||||
def table_udf(fn=None, **kwargs):
|
||||
"""Decorate a table function (UDTF): each input row may emit zero or
|
||||
more output rows. Only usable in materialized views.
|
||||
|
||||
class Chunk(TypedDict):
|
||||
chunk: str
|
||||
chunk_idx: int
|
||||
|
||||
@table_udf
|
||||
def chunker(body: str) -> list[Chunk]: ...
|
||||
"""
|
||||
kwargs["table"] = True
|
||||
if fn is not None:
|
||||
return Udf(fn, **kwargs)
|
||||
return lambda f: Udf(f, **kwargs)
|
||||
|
||||
|
||||
# -- view / job handles (thin references over a connection) -------------
|
||||
|
||||
|
||||
def struct_field_types(returns: str) -> "list[str]":
|
||||
"""Field type strings of a STRUCT(...) SQL type, in declared order."""
|
||||
inner = returns.strip()[len("STRUCT(") : -1]
|
||||
fields, depth, start = [], 0, 0
|
||||
for i, c in enumerate(inner):
|
||||
if c in "([":
|
||||
depth += 1
|
||||
elif c in ")]":
|
||||
depth -= 1
|
||||
elif c == "," and depth == 0:
|
||||
fields.append(inner[start:i].strip())
|
||||
start = i + 1
|
||||
fields.append(inner[start:].strip())
|
||||
# Each field is "name TYPE"; drop the name.
|
||||
return [f.split(None, 1)[1] for f in fields]
|
||||
|
||||
|
||||
def build_view_query(source, select) -> str:
|
||||
"""Assemble a view SELECT from a source (name or table) and select
|
||||
items: a column name, an expression string, a (alias, expression)
|
||||
tuple, or a @udf/@table_udf object."""
|
||||
src = source.name if hasattr(source, "name") else source
|
||||
items = []
|
||||
for item in select:
|
||||
if isinstance(item, Udf):
|
||||
items.append(item.expression())
|
||||
elif isinstance(item, tuple):
|
||||
alias, expr = item
|
||||
expr = expr.expression() if isinstance(expr, Udf) else expr
|
||||
items.append(f"{expr} AS {alias}")
|
||||
else:
|
||||
items.append(item)
|
||||
return f"SELECT {', '.join(items)} FROM {src}"
|
||||
|
||||
|
||||
def _job_id_matches(handle_id: str, listed_id: str) -> bool:
|
||||
# The refresh/backfill endpoints return the submission id (a uuid), but
|
||||
# the agent names the manifest job "<table>-<type>-<first 8 of the
|
||||
# submission id>" -- which is what list_jobs and cancel report. Match the
|
||||
# canonical id directly, or by that submission prefix.
|
||||
if listed_id == handle_id:
|
||||
return True
|
||||
prefix = handle_id[:8]
|
||||
return len(prefix) >= 4 and prefix in listed_id
|
||||
|
||||
|
||||
class MaterializedView:
|
||||
"""A reference to a materialized view (name + connection). Operations are
|
||||
server-backed connection calls bound to the name.
|
||||
|
||||
``create_materialized_view`` returns one of these; ``job_id`` is the
|
||||
initial-population job (None when the view was created with no data), so
|
||||
``db.create_materialized_view(...).wait()`` blocks until it is populated.
|
||||
"""
|
||||
|
||||
def __init__(self, conn, name: str, job_id: "str | None" = None):
|
||||
self.conn = conn
|
||||
self.name = name
|
||||
#: initial-population job id from create, or None (with_no_data).
|
||||
self.job_id = job_id
|
||||
|
||||
def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
|
||||
"""Block until the initial-population job (from create) finishes.
|
||||
A no-op when the view was created with no data."""
|
||||
if self.job_id is None:
|
||||
return "finished"
|
||||
return JobHandle(self.conn, self.job_id, table=self.name).wait(
|
||||
timeout=timeout, poll=poll
|
||||
)
|
||||
|
||||
def refresh(self, full: bool = False) -> "JobHandle":
|
||||
"""Refresh the materialized view; returns a `JobHandle` to wait on,
|
||||
poll, or cancel (``view.refresh().wait()``).
|
||||
|
||||
``full=True`` forces a full rebuild (recompute and replace every row)
|
||||
instead of the default incremental refresh. A full rebuild preserves
|
||||
the view's indexes -- they are reindexed by the distributed indexer.
|
||||
"""
|
||||
job_id = self.conn._refresh_materialized_view(self.name, full=full)
|
||||
return JobHandle(self.conn, job_id, table=self.name)
|
||||
|
||||
def explain_refresh(self, full: bool = False):
|
||||
"""Plan a refresh without running it (EXPLAIN REFRESH)."""
|
||||
return self.conn.explain_refresh_materialized_view(self.name, full=full)
|
||||
|
||||
def alter(self, auto_refresh: bool) -> None:
|
||||
self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)
|
||||
|
||||
def drop(self) -> None:
|
||||
self.conn.drop_materialized_view(self.name)
|
||||
|
||||
# A materialized view is a first-class table: it can be indexed and
|
||||
# searched like any other. These open the materialized dataset by name and
|
||||
# delegate. Indexes declared this way are recorded against the view, so the
|
||||
# engine re-applies them after a full refresh rebuilds the dataset (a full
|
||||
# refresh overwrites the dataset, which would otherwise drop its indices).
|
||||
def _table(self):
|
||||
return self.conn.open_table(self.name)
|
||||
|
||||
def create_index(self, *args, **kwargs):
|
||||
"""Build an index on the materialized view (see Table.create_index)."""
|
||||
return self._table().create_index(*args, **kwargs)
|
||||
|
||||
def create_scalar_index(self, *args, **kwargs):
|
||||
"""Build a scalar index on the materialized view."""
|
||||
return self._table().create_scalar_index(*args, **kwargs)
|
||||
|
||||
def create_fts_index(self, *args, **kwargs):
|
||||
"""Build a full-text-search index on the materialized view."""
|
||||
return self._table().create_fts_index(*args, **kwargs)
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
"""Search the materialized view (vector / FTS / hybrid)."""
|
||||
return self._table().search(*args, **kwargs)
|
||||
|
||||
def lineage(self, column=None, *, direction=None, depth=None):
|
||||
"""Lineage of the materialized view (or one of its columns). Delegates
|
||||
to the backing table; the server already includes the view's sources
|
||||
and downstream dependents. Returns a `Lineage`."""
|
||||
return self._table().lineage(column, direction=direction, depth=depth)
|
||||
|
||||
|
||||
_PROGRESS = re.compile(r"(\d+)/(\d+)")
|
||||
|
||||
|
||||
class JobFailedError(RuntimeError):
|
||||
"""Raised by ``JobHandle.wait()`` when the server reports the job ``failed``.
|
||||
|
||||
Carries the server-side error so a doomed backfill (e.g. a multi-column
|
||||
``REFRESH COLUMN`` of a scalar UDF) surfaces its real cause promptly,
|
||||
instead of the caller blocking until ``wait()``'s timeout.
|
||||
"""
|
||||
|
||||
def __init__(self, job_id: str, error: "str | None"):
|
||||
self.job_id = job_id
|
||||
self.error = error
|
||||
super().__init__(f"job {job_id} failed: {error or 'unknown error'}")
|
||||
|
||||
|
||||
class JobHandle:
|
||||
"""A reference to an inflight server-side job, with polling helpers."""
|
||||
|
||||
#: How long an unseen job is treated as still materializing (submission
|
||||
#: -> agent cycle -> manifest write is async).
|
||||
GRACE_SECONDS = 20.0
|
||||
|
||||
def __init__(self, conn, job_id: str, table: "str | None" = None):
|
||||
self.conn = conn
|
||||
self.id = job_id
|
||||
#: The job's table, when known (refresh_column / MV refresh). Lets the
|
||||
#: server resolve this job with an O(1) single-node read; without it the
|
||||
#: lookup scans the database's active jobs (still correct).
|
||||
self.table = table
|
||||
self._created = time.monotonic()
|
||||
self._seen = False
|
||||
|
||||
def _job(self):
|
||||
# Poll by id (one job), not list_jobs (every active job): the server
|
||||
# matches the submission/manifest id and reads just this table's node.
|
||||
return self.conn.get_job(self.id, self.table)
|
||||
|
||||
def status(self) -> str:
|
||||
"""pending / running / cancelling / stale, or 'finished' once the
|
||||
job has left the inflight listing."""
|
||||
job = self._job()
|
||||
if job is not None:
|
||||
self._seen = True
|
||||
return job.state
|
||||
if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
|
||||
return "pending"
|
||||
return "finished"
|
||||
|
||||
def progress(self) -> "tuple[int, int] | None":
|
||||
"""(units_done, units_total) while running, else None."""
|
||||
job = self._job()
|
||||
if job is not None and job.units_total is not None:
|
||||
return job.units_done or 0, job.units_total
|
||||
return None
|
||||
|
||||
def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
state = self.status()
|
||||
if state in ("finished", "stale"):
|
||||
return state
|
||||
if state == "failed":
|
||||
# Terminal failure -- surface the server error now, don't block
|
||||
# until `timeout`. `finalize` wrote it to the job's status node.
|
||||
job = self._job()
|
||||
raise JobFailedError(self.id, job.error if job is not None else None)
|
||||
if state == "pending":
|
||||
time.sleep(min(poll, 0.5))
|
||||
continue
|
||||
job = self._job()
|
||||
if job is not None and job.committed:
|
||||
return "finished"
|
||||
time.sleep(poll)
|
||||
raise TimeoutError(f"job {self.id} still {self.status()} after {timeout}s")
|
||||
|
||||
def cancel(self) -> None:
|
||||
# Cancel by the canonical manifest id (what cancel matches), found
|
||||
# via the submission prefix; fall back to the raw id.
|
||||
job = self._job()
|
||||
self.conn.cancel_job(job.job_id if job is not None else self.id)
|
||||
|
||||
|
||||
class AsyncMaterializedView:
|
||||
"""Async reference to a materialized view (name + async connection)."""
|
||||
|
||||
def __init__(self, conn, name: str, job_id: "str | None" = None):
|
||||
self.conn = conn
|
||||
self.name = name
|
||||
#: initial-population job id from create, or None (with_no_data).
|
||||
self.job_id = job_id
|
||||
|
||||
async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
|
||||
"""Block until the initial-population job (from create) finishes.
|
||||
A no-op when the view was created with no data."""
|
||||
if self.job_id is None:
|
||||
return "finished"
|
||||
return await AsyncJobHandle(self.conn, self.job_id, table=self.name).wait(
|
||||
timeout=timeout, poll=poll
|
||||
)
|
||||
|
||||
async def refresh(self, full: bool = False) -> "AsyncJobHandle":
|
||||
"""Refresh the materialized view; returns an `AsyncJobHandle` to wait
|
||||
on, poll, or cancel.
|
||||
|
||||
``full=True`` forces a full rebuild instead of an incremental refresh
|
||||
(indexes are preserved and reindexed by the distributed indexer).
|
||||
"""
|
||||
job_id = await self.conn._refresh_materialized_view(self.name, full=full)
|
||||
return AsyncJobHandle(self.conn, job_id, table=self.name)
|
||||
|
||||
async def explain_refresh(self, full: bool = False):
|
||||
return await self.conn.explain_refresh_materialized_view(self.name, full=full)
|
||||
|
||||
async def alter(self, auto_refresh: bool) -> None:
|
||||
await self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)
|
||||
|
||||
async def drop(self) -> None:
|
||||
await self.conn.drop_materialized_view(self.name)
|
||||
|
||||
async def lineage(self, column=None, *, direction=None, depth=None):
|
||||
"""Lineage of the materialized view (or column). Returns a `Lineage`."""
|
||||
return await self.conn.lineage(
|
||||
self.name, column, direction=direction, depth=depth
|
||||
)
|
||||
|
||||
|
||||
class AsyncJobHandle:
|
||||
"""Async reference to an inflight server-side job, with polling helpers."""
|
||||
|
||||
GRACE_SECONDS = 20.0
|
||||
|
||||
def __init__(self, conn, job_id: str, table: "str | None" = None):
|
||||
self.conn = conn
|
||||
self.id = job_id
|
||||
#: See JobHandle.table -- enables an O(1) by-id lookup when known.
|
||||
self.table = table
|
||||
self._created = time.monotonic()
|
||||
self._seen = False
|
||||
|
||||
async def _job(self):
|
||||
# Poll by id, not list_jobs (see JobHandle._job).
|
||||
return await self.conn.get_job(self.id, self.table)
|
||||
|
||||
async def status(self) -> str:
|
||||
job = await self._job()
|
||||
if job is not None:
|
||||
self._seen = True
|
||||
return job.state
|
||||
if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
|
||||
return "pending"
|
||||
return "finished"
|
||||
|
||||
async def progress(self) -> "tuple[int, int] | None":
|
||||
job = await self._job()
|
||||
if job is not None and job.units_total is not None:
|
||||
return job.units_done or 0, job.units_total
|
||||
return None
|
||||
|
||||
async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
state = await self.status()
|
||||
if state in ("finished", "stale"):
|
||||
return state
|
||||
if state == "failed":
|
||||
# Terminal failure -- surface the server error now, don't block
|
||||
# until `timeout`. `finalize` wrote it to the job's status node.
|
||||
job = await self._job()
|
||||
raise JobFailedError(self.id, job.error if job is not None else None)
|
||||
if state == "pending":
|
||||
await asyncio.sleep(min(poll, 0.5))
|
||||
continue
|
||||
job = await self._job()
|
||||
if job is not None and job.committed:
|
||||
return "finished"
|
||||
await asyncio.sleep(poll)
|
||||
raise TimeoutError(
|
||||
f"job {self.id} still {await self.status()} after {timeout}s"
|
||||
)
|
||||
|
||||
async def cancel(self) -> None:
|
||||
job = await self._job()
|
||||
await self.conn.cancel_job(job.job_id if job is not None else self.id)
|
||||
@@ -373,15 +373,9 @@ def _(value: list):
|
||||
@value_to_sql.register(dict)
|
||||
def _(value: dict):
|
||||
# https://datafusion.apache.org/user-guide/sql/scalar_functions.html#named-struct
|
||||
# Render the field name through value_to_sql(str(...)) as well so that keys
|
||||
# containing characters meaningful in SQL (e.g. a single quote) are escaped
|
||||
# the same way string values are. A bare f"'{k}'" would emit invalid SQL for
|
||||
# a key like "it's".
|
||||
return (
|
||||
"named_struct("
|
||||
+ ", ".join(
|
||||
f"{value_to_sql(str(k))}, {value_to_sql(v)}" for k, v in value.items()
|
||||
)
|
||||
+ ", ".join(f"'{k}', {value_to_sql(v)}" for k, v in value.items())
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
@@ -91,9 +91,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
||||
# Can recreate if replace=True
|
||||
await some_table.create_index("id", replace=True)
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices).startswith(
|
||||
'[IndexConfig(name="id_idx", index_type="BTree", columns=["id"]'
|
||||
)
|
||||
assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "BTree"
|
||||
assert indices[0].columns == ["id"]
|
||||
@@ -108,27 +106,6 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
||||
assert len(indices) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_config_repr(db_async):
|
||||
# Use >= 1000 rows so the thousands separator in the repr is exercised.
|
||||
nrows = 1500
|
||||
table = await db_async.create_table(
|
||||
"repr_table", pa.Table.from_pydict({"id": list(range(nrows))})
|
||||
)
|
||||
await table.create_index("id", config=BTree())
|
||||
indices = await table.list_indices()
|
||||
assert len(indices) == 1
|
||||
|
||||
r = repr(indices[0])
|
||||
assert r.startswith('IndexConfig(name="id_idx", index_type="BTree", columns=["id"]')
|
||||
# Integer counts use `_` thousands separators (valid Python int syntax).
|
||||
assert "num_indexed_rows=1_500" in r
|
||||
assert "num_unindexed_rows=0" in r
|
||||
# created_at renders as a datetime so the value round-trips.
|
||||
assert "created_at=datetime.datetime(" in r
|
||||
assert r.endswith(")")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_nested_scalar_index_lists_canonical_paths(db_async):
|
||||
metadata_type = pa.struct(
|
||||
@@ -221,9 +198,7 @@ async def test_create_nested_scalar_index_lists_canonical_paths(db_async):
|
||||
async def test_create_fixed_size_binary_index(some_table: AsyncTable):
|
||||
await some_table.create_index("fsb", config=BTree())
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices).startswith(
|
||||
'[IndexConfig(name="fsb_idx", index_type="BTree", columns=["fsb"]'
|
||||
)
|
||||
assert str(indices) == '[Index(BTree, columns=["fsb"], name="fsb_idx")]'
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "BTree"
|
||||
assert indices[0].columns == ["fsb"]
|
||||
@@ -272,9 +247,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
|
||||
async def test_create_label_list_index(some_table: AsyncTable):
|
||||
await some_table.create_index("tags", config=LabelList())
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices).startswith(
|
||||
'[IndexConfig(name="tags_idx", index_type="LabelList", columns=["tags"]'
|
||||
)
|
||||
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
|
||||
plan = await some_table.query().where("array_has(tags, 'tag0')").explain_plan()
|
||||
assert "ScalarIndexQuery" in plan
|
||||
|
||||
@@ -289,9 +262,7 @@ async def test_create_large_list_label_list_index(db_async):
|
||||
|
||||
await table.create_index("tags", config=LabelList())
|
||||
indices = await table.list_indices()
|
||||
assert str(indices).startswith(
|
||||
'[IndexConfig(name="tags_idx", index_type="LabelList", columns=["tags"]'
|
||||
)
|
||||
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
|
||||
plan = await table.query().where("array_has(tags, 'shared')").explain_plan()
|
||||
assert "ScalarIndexQuery" in plan
|
||||
|
||||
@@ -328,9 +299,7 @@ async def test_create_label_list_index_rejects_list_struct(db_async):
|
||||
async def test_full_text_search_index(some_table: AsyncTable):
|
||||
await some_table.create_index("tags", config=FTS(with_position=False))
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices).startswith(
|
||||
'[IndexConfig(name="tags_idx", index_type="FTS", columns=["tags"]'
|
||||
)
|
||||
assert str(indices) == '[Index(FTS, columns=["tags"], name="tags_idx")]'
|
||||
|
||||
await some_table.prewarm_index("tags_idx")
|
||||
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
"""JobHandle.wait() terminal-state handling.
|
||||
|
||||
Regression coverage for the cluster backfill-failure hang: the server reports a
|
||||
doomed job as ``state="failed"`` within seconds, but ``wait()`` used to ignore
|
||||
``failed`` and block until its (default 3600s) timeout. These tests pin that a
|
||||
``failed`` job raises ``JobFailedError`` promptly, carrying the server error.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from lancedb.udf import JobHandle, AsyncJobHandle, JobFailedError
|
||||
|
||||
|
||||
class FakeJobInfo:
|
||||
"""Mirror of the pyo3 builtins.JobInfo fields wait()/status() read."""
|
||||
|
||||
def __init__(self, state, error=None, committed=False, units_total=None):
|
||||
self.state = state
|
||||
self.error = error
|
||||
self.committed = committed
|
||||
self.units_total = units_total
|
||||
self.units_done = None
|
||||
self.job_id = "job-1"
|
||||
|
||||
|
||||
class FakeConn:
|
||||
"""get_job() walks a scripted list of JobInfo (or None) snapshots, holding
|
||||
the last one once exhausted, so wait() polls a deterministic timeline."""
|
||||
|
||||
def __init__(self, snapshots):
|
||||
self._snaps = list(snapshots)
|
||||
self.calls = 0
|
||||
|
||||
def get_job(self, job_id, table=None):
|
||||
snap = self._snaps[min(self.calls, len(self._snaps) - 1)]
|
||||
self.calls += 1
|
||||
return snap
|
||||
|
||||
|
||||
class AsyncFakeConn(FakeConn):
|
||||
async def get_job(self, job_id, table=None):
|
||||
return FakeConn.get_job(self, job_id, table)
|
||||
|
||||
|
||||
def test_wait_raises_on_failed_promptly():
|
||||
# pending -> failed: wait() must raise the server error, not TimeoutError.
|
||||
conn = FakeConn(
|
||||
[None, FakeJobInfo("failed", error="multi-column backfill needs a STRUCT")]
|
||||
)
|
||||
jh = JobHandle(conn, "job-1", table="t")
|
||||
t0 = time.monotonic()
|
||||
with pytest.raises(JobFailedError) as exc:
|
||||
jh.wait(timeout=30, poll=0.01)
|
||||
assert time.monotonic() - t0 < 5 # prompt, nowhere near the 30s timeout
|
||||
assert "STRUCT" in str(exc.value)
|
||||
assert exc.value.error == "multi-column backfill needs a STRUCT"
|
||||
assert exc.value.job_id == "job-1"
|
||||
|
||||
|
||||
def test_wait_returns_finished_on_success():
|
||||
# running -> finished (job left the inflight listing) returns normally.
|
||||
conn = FakeConn([FakeJobInfo("running", units_total=2), None])
|
||||
jh = JobHandle(conn, "job-1", table="t")
|
||||
jh._seen = True # already observed, so a None now means "finished" not grace
|
||||
assert jh.wait(timeout=30, poll=0.01) == "finished"
|
||||
|
||||
|
||||
def test_wait_returns_finished_on_committed():
|
||||
# A committed job that is still listed resolves to finished.
|
||||
conn = FakeConn([FakeJobInfo("running", committed=True, units_total=2)])
|
||||
jh = JobHandle(conn, "job-1", table="t")
|
||||
jh._seen = True
|
||||
assert jh.wait(timeout=30, poll=0.01) == "finished"
|
||||
|
||||
|
||||
def test_async_wait_raises_on_failed_promptly():
|
||||
conn = AsyncFakeConn([None, FakeJobInfo("failed", error="boom")])
|
||||
jh = AsyncJobHandle(conn, "job-1", table="t")
|
||||
|
||||
async def run():
|
||||
t0 = time.monotonic()
|
||||
with pytest.raises(JobFailedError) as exc:
|
||||
await jh.wait(timeout=30, poll=0.01)
|
||||
assert time.monotonic() - t0 < 5
|
||||
assert exc.value.error == "boom"
|
||||
|
||||
asyncio.run(run())
|
||||
@@ -5,11 +5,11 @@
|
||||
|
||||
import tempfile
|
||||
import shutil
|
||||
import sys
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
import lancedb
|
||||
from lance_namespace.errors import NamespaceNotEmptyError, TableNotFoundError
|
||||
from lancedb.namespace import _MAX_QUERY_K
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
|
||||
|
||||
@@ -65,9 +65,6 @@ def _namespace_lance_table(namespace_client: _NamespaceClient) -> LanceTable:
|
||||
table._namespace_path = ["geneva"]
|
||||
table._namespace_client = namespace_client
|
||||
table._pushdown_operations = {"QueryTable"}
|
||||
# This test exercises the Python-side pushdown path (non-native client), so
|
||||
# pushdown is not routed to Rust.
|
||||
table._route_pushdown_to_rust = False
|
||||
return table
|
||||
|
||||
|
||||
@@ -808,37 +805,6 @@ class TestPushdownOperations:
|
||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||
assert len(db._namespace_client_pushdown_operations) == 0
|
||||
|
||||
def test_route_pushdown_to_rust_for_native_rest(self):
|
||||
"""A natively-built rest connection must defer QueryTable pushdown to
|
||||
Rust so reads carry the x-lancedb-min-timestamp read-freshness header."""
|
||||
db = lancedb.connect_namespace(
|
||||
"rest",
|
||||
{"uri": "http://localhost:12345"},
|
||||
namespace_client_pushdown_operations=["QueryTable"],
|
||||
)
|
||||
assert db._route_pushdown_to_rust is True
|
||||
|
||||
def test_route_pushdown_to_rust_false_for_dir(self):
|
||||
"""A non-native (dir) connection keeps the Python pushdown path."""
|
||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||
assert db._route_pushdown_to_rust is False
|
||||
|
||||
def test_async_route_pushdown_to_rust_for_native_rest(self):
|
||||
"""The async connection must not silently bypass the read-freshness fix:
|
||||
a natively-built rest connection defers pushdown to Rust (regression test
|
||||
for the async path omitting the freshness header)."""
|
||||
db = lancedb.connect_namespace_async(
|
||||
"rest",
|
||||
{"uri": "http://localhost:12345"},
|
||||
namespace_client_pushdown_operations=["QueryTable"],
|
||||
)
|
||||
assert db._route_pushdown_to_rust is True
|
||||
|
||||
def test_async_route_pushdown_to_rust_false_for_dir(self):
|
||||
"""The async non-native (dir) connection keeps the Python pushdown path."""
|
||||
db = lancedb.connect_namespace_async("dir", {"root": self.temp_dir})
|
||||
assert db._route_pushdown_to_rust is False
|
||||
|
||||
def test_lance_table_to_arrow_uses_query_pushdown(self):
|
||||
namespace_client = _NamespaceClient()
|
||||
table = _namespace_lance_table(namespace_client)
|
||||
@@ -850,13 +816,10 @@ class TestPushdownOperations:
|
||||
["geneva", "hist"],
|
||||
["geneva", "hist"],
|
||||
]
|
||||
# Unlimited reads cap k at i32::MAX (the namespace query_table `k`
|
||||
# field is i32); sys.maxsize would overflow the Rust binding.
|
||||
assert [request.k for request in namespace_client.requests] == [
|
||||
_MAX_QUERY_K,
|
||||
_MAX_QUERY_K,
|
||||
sys.maxsize,
|
||||
sys.maxsize,
|
||||
]
|
||||
assert all(r.k <= 2**31 - 1 for r in namespace_client.requests)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -911,13 +874,10 @@ class TestAsyncPushdownOperations:
|
||||
["geneva", "hist"],
|
||||
["geneva", "hist"],
|
||||
]
|
||||
# Unlimited reads cap k at i32::MAX (the namespace query_table `k`
|
||||
# field is i32); sys.maxsize would overflow the Rust binding.
|
||||
assert [request.k for request in namespace_client.requests] == [
|
||||
_MAX_QUERY_K,
|
||||
_MAX_QUERY_K,
|
||||
sys.maxsize,
|
||||
sys.maxsize,
|
||||
]
|
||||
assert all(r.k <= 2**31 - 1 for r in namespace_client.requests)
|
||||
|
||||
|
||||
def test_local_table_to_arrow_and_to_pandas_are_unchanged(tmp_path):
|
||||
|
||||
@@ -188,18 +188,6 @@ def test_nested_struct_list():
|
||||
assert schema == expect_schema
|
||||
|
||||
|
||||
def test_bare_generic_raises_type_error():
|
||||
# A bare, unparameterised List/Tuple has no element type to map to Arrow.
|
||||
# It should raise a clear TypeError, not crash with AttributeError: __args__.
|
||||
for bare in (List, Tuple):
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
items: bare
|
||||
|
||||
with pytest.raises(TypeError, match="unsupported type"):
|
||||
pydantic_to_schema(TestModel)
|
||||
|
||||
|
||||
def test_nested_struct_list_optional():
|
||||
class SplitInfo(pydantic.BaseModel):
|
||||
start_frame: int
|
||||
|
||||
@@ -301,16 +301,6 @@ def test_create_table(mem_db: DBConnection):
|
||||
assert expected == tbl
|
||||
|
||||
|
||||
def test_create_table_rejects_single_dictionary(mem_db: DBConnection):
|
||||
data = {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}
|
||||
with pytest.raises(ValueError) as excep_info:
|
||||
mem_db.create_table("test", data=data)
|
||||
assert (
|
||||
str(excep_info.value) == "Cannot create or add rows from a single dictionary. "
|
||||
"Use a list of dictionaries instead."
|
||||
)
|
||||
|
||||
|
||||
def test_empty_table(mem_db: DBConnection):
|
||||
schema = pa.schema(
|
||||
[
|
||||
@@ -340,8 +330,8 @@ def test_add_dictionary(mem_db: DBConnection):
|
||||
with pytest.raises(ValueError) as excep_info:
|
||||
tbl.add(data=data)
|
||||
assert (
|
||||
str(excep_info.value) == "Cannot create or add rows from a single dictionary. "
|
||||
"Use a list of dictionaries instead."
|
||||
str(excep_info.value)
|
||||
== "Cannot add a single dictionary to a table. Use a list."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -149,21 +149,6 @@ def test_value_to_sql_dict():
|
||||
assert value_to_sql({}) == "named_struct()"
|
||||
|
||||
|
||||
def test_value_to_sql_dict_key_escaping():
|
||||
# Struct field names that contain a single quote must be escaped (doubled)
|
||||
# the same way string values are, otherwise value_to_sql emits invalid SQL
|
||||
# such as named_struct('it's', 1).
|
||||
assert value_to_sql({"it's": 1}) == "named_struct('it''s', 1)"
|
||||
assert (
|
||||
value_to_sql({"o'brien": "d'angelo"}) == "named_struct('o''brien', 'd''angelo')"
|
||||
)
|
||||
# Escaping also applies to keys of nested structs.
|
||||
assert (
|
||||
value_to_sql({"outer": {"in'r": 1}})
|
||||
== "named_struct('outer', named_struct('in''r', 1))"
|
||||
)
|
||||
|
||||
|
||||
def test_value_to_sql_numpy_scalars():
|
||||
# numpy scalars (e.g. pulled from an ndarray or a pandas column) must
|
||||
# convert the same way as their native Python counterparts. np.float64
|
||||
|
||||
@@ -18,10 +18,7 @@ use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
connection::NamespaceClientPushdownOperation,
|
||||
database::namespace::LanceNamespaceDatabase,
|
||||
database::{
|
||||
CreateFunctionRequest, CreateMaterializedViewRequest, CreateTableMode, Database,
|
||||
ReadConsistency, RefreshMaterializedViewRequest, TableLineageRequest,
|
||||
},
|
||||
database::{CreateTableMode, Database, ReadConsistency},
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
@@ -30,92 +27,6 @@ use pyo3::{
|
||||
types::{PyDict, PyDictMethods},
|
||||
};
|
||||
|
||||
/// A registered function, as returned by `list_functions`.
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone)]
|
||||
pub struct FunctionInfo {
|
||||
pub name: String,
|
||||
pub language: String,
|
||||
pub return_type: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// A registered materialized view definition.
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone)]
|
||||
pub struct MaterializedViewInfo {
|
||||
pub name: String,
|
||||
pub source_table: String,
|
||||
pub projection: Vec<String>,
|
||||
pub udf_columns: Vec<String>,
|
||||
pub filter: Option<String>,
|
||||
pub auto_refresh: bool,
|
||||
}
|
||||
|
||||
/// One inflight server-side job.
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone)]
|
||||
pub struct JobInfo {
|
||||
pub table: String,
|
||||
pub job_id: String,
|
||||
pub job_type: String,
|
||||
pub state: String,
|
||||
pub column: Option<String>,
|
||||
pub age_seconds: Option<i64>,
|
||||
pub command: Option<String>,
|
||||
pub units_done: Option<i64>,
|
||||
pub units_total: Option<i64>,
|
||||
pub committed: bool,
|
||||
pub rows_skipped: u64,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// One durable, completed/terminal server-side job record (SHOW JOB HISTORY).
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone)]
|
||||
pub struct JobHistoryEntry {
|
||||
pub table: String,
|
||||
pub job_id: String,
|
||||
pub job_type: String,
|
||||
pub state: String,
|
||||
pub column: Option<String>,
|
||||
pub created_ms: i64,
|
||||
pub updated_ms: i64,
|
||||
pub completed_ms: Option<i64>,
|
||||
pub rows_processed: Option<i64>,
|
||||
pub rows_skipped: Option<i64>,
|
||||
pub error: Option<String>,
|
||||
pub events: Option<String>,
|
||||
}
|
||||
|
||||
/// One per-row UDF error recorded by `error_policy=skip` (SHOW ERRORS).
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone)]
|
||||
pub struct JobErrorEntry {
|
||||
pub job_id: String,
|
||||
pub table: String,
|
||||
pub column: String,
|
||||
pub error_type: String,
|
||||
pub error_message: String,
|
||||
pub fragment_id: Option<i64>,
|
||||
pub source_row_id: Option<i64>,
|
||||
pub table_version: Option<i64>,
|
||||
pub age_seconds: Option<i64>,
|
||||
}
|
||||
|
||||
/// The plan a REFRESH MATERIALIZED VIEW would execute (EXPLAIN REFRESH).
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone)]
|
||||
pub struct MvRefreshPlan {
|
||||
pub table_name: String,
|
||||
pub has_work: bool,
|
||||
pub source_version: u64,
|
||||
pub last_refreshed_version: Option<u64>,
|
||||
pub full_refresh: bool,
|
||||
pub rebuild: bool,
|
||||
pub units_total: u64,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct Connection {
|
||||
inner: Option<LanceConnection>,
|
||||
@@ -399,308 +310,6 @@ impl Connection {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, language, return_type, body, options=None))]
|
||||
pub fn create_function(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
language: String,
|
||||
return_type: String,
|
||||
body: String,
|
||||
options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.create_function(CreateFunctionRequest {
|
||||
name,
|
||||
language,
|
||||
return_type,
|
||||
body,
|
||||
options: options.unwrap_or_default(),
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_functions(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let functions = inner.list_functions().await.infer_error()?;
|
||||
Ok(functions
|
||||
.into_iter()
|
||||
.map(|f| FunctionInfo {
|
||||
name: f.name,
|
||||
language: f.language,
|
||||
return_type: f.return_type,
|
||||
description: f.description,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_function(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.drop_function(&name).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, query, auto_refresh=false, with_no_data=false, partition_by=None))]
|
||||
pub fn create_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
query: String,
|
||||
auto_refresh: bool,
|
||||
with_no_data: bool,
|
||||
partition_by: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.create_materialized_view(CreateMaterializedViewRequest {
|
||||
name,
|
||||
query,
|
||||
auto_refresh,
|
||||
with_no_data,
|
||||
partition_by,
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, full=false, src_version=None, num_workers=None, max_workers=None))]
|
||||
pub fn refresh_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.refresh_materialized_view(RefreshMaterializedViewRequest {
|
||||
name,
|
||||
full,
|
||||
src_version,
|
||||
num_workers,
|
||||
max_workers,
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
/// Derived-compute lineage of a table/view (or column), returned as the
|
||||
/// server's lineage JSON string (the Python layer parses it).
|
||||
pub fn table_lineage(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
column: Option<String>,
|
||||
direction: Option<String>,
|
||||
depth: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.table_lineage(TableLineageRequest {
|
||||
name,
|
||||
column,
|
||||
direction,
|
||||
depth,
|
||||
})
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, full=false, src_version=None))]
|
||||
pub fn explain_refresh_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let p = inner
|
||||
.explain_refresh_materialized_view(&name, full, src_version)
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(MvRefreshPlan {
|
||||
table_name: p.table_name,
|
||||
has_work: p.has_work,
|
||||
source_version: p.source_version,
|
||||
last_refreshed_version: p.last_refreshed_version,
|
||||
full_refresh: p.full_refresh,
|
||||
rebuild: p.rebuild,
|
||||
units_total: p.units_total,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn alter_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
auto_refresh: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.alter_materialized_view(&name, auto_refresh)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_materialized_view(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.drop_materialized_view(&name).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_materialized_views(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let views = inner.list_materialized_views().await.infer_error()?;
|
||||
Ok(views
|
||||
.into_iter()
|
||||
.map(|v| MaterializedViewInfo {
|
||||
name: v.name,
|
||||
source_table: v.source_table,
|
||||
projection: v.projection,
|
||||
udf_columns: v.udf_columns,
|
||||
filter: v.filter,
|
||||
auto_refresh: v.auto_refresh,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_jobs(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let jobs = inner.list_jobs().await.infer_error()?;
|
||||
Ok(jobs
|
||||
.into_iter()
|
||||
.map(|j| JobInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
age_seconds: j.age_seconds,
|
||||
command: j.command,
|
||||
units_done: j.units_done,
|
||||
units_total: j.units_total,
|
||||
committed: j.committed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cancel_job(self_: PyRef<'_, Self>, job_id: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.cancel_job(&job_id).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (job_id, table=None))]
|
||||
pub fn get_job(
|
||||
self_: PyRef<'_, Self>,
|
||||
job_id: String,
|
||||
table: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let job = inner
|
||||
.get_job(&job_id, table.as_deref())
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(job.map(|j| JobInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
age_seconds: j.age_seconds,
|
||||
command: j.command,
|
||||
units_done: j.units_done,
|
||||
units_total: j.units_total,
|
||||
committed: j.committed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (job_id=None))]
|
||||
pub fn job_history(
|
||||
self_: PyRef<'_, Self>,
|
||||
job_id: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let rows = inner.job_history(job_id.as_deref()).await.infer_error()?;
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| JobHistoryEntry {
|
||||
table: r.table,
|
||||
job_id: r.job_id,
|
||||
job_type: r.job_type,
|
||||
state: r.state,
|
||||
column: r.column,
|
||||
created_ms: r.created_ms,
|
||||
updated_ms: r.updated_ms,
|
||||
completed_ms: r.completed_ms,
|
||||
rows_processed: r.rows_processed,
|
||||
rows_skipped: r.rows_skipped,
|
||||
error: r.error,
|
||||
events: r.events,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (job_id=None, table=None))]
|
||||
pub fn errors(
|
||||
self_: PyRef<'_, Self>,
|
||||
job_id: Option<String>,
|
||||
table: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let rows = inner
|
||||
.errors(job_id.as_deref(), table.as_deref())
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|e| JobErrorEntry {
|
||||
job_id: e.job_id,
|
||||
table: e.table,
|
||||
column: e.column,
|
||||
error_type: e.error_type,
|
||||
error_message: e.error_message,
|
||||
fragment_id: e.fragment_id,
|
||||
source_row_id: e.source_row_id,
|
||||
table_version: e.table_version,
|
||||
age_seconds: e.age_seconds,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (cur_name, new_name, cur_namespace_path=None, new_namespace_path=None))]
|
||||
pub fn rename_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
@@ -930,7 +539,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn connect(
|
||||
py: Python<'_>,
|
||||
@@ -944,7 +553,6 @@ pub fn connect(
|
||||
session: Option<crate::session::Session>,
|
||||
manifest_enabled: bool,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
oauth_config: Option<crate::oauth::PyOAuthConfig>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
@@ -974,11 +582,6 @@ pub fn connect(
|
||||
if let Some(client_config) = client_config {
|
||||
builder = builder.client_config(client_config.into());
|
||||
}
|
||||
if let Some(oauth_config) = oauth_config {
|
||||
let config: lancedb::remote::oauth::OAuthConfig =
|
||||
oauth_config.try_into().infer_error()?;
|
||||
builder = builder.oauth_config(config);
|
||||
}
|
||||
if let Some(session) = session {
|
||||
builder = builder.session(session.inner.clone());
|
||||
}
|
||||
@@ -1007,38 +610,24 @@ pub fn connect_namespace_client(
|
||||
namespace_client_impl: Option<String>,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Connection> {
|
||||
let namespace_client = extract_namespace_arc(py, namespace_client)?;
|
||||
let read_consistency_interval = read_consistency_interval.map(Duration::from_secs_f64);
|
||||
let namespace_client_pushdown_operations =
|
||||
parse_namespace_client_pushdown_operations(namespace_client_pushdown_operations)?;
|
||||
let ns_impl = namespace_client_impl.unwrap_or_else(|| "python".to_string());
|
||||
let ns_properties = namespace_client_properties.unwrap_or_default();
|
||||
let storage_options = storage_options.unwrap_or_default();
|
||||
let session = session.map(|s| s.inner.clone());
|
||||
|
||||
// Prefer building the namespace natively from (impl, properties) so the
|
||||
// read-freshness provider installed
|
||||
let database = if build_namespace_natively(namespace_client_impl.as_deref(), &ns_properties) {
|
||||
let ns_impl = namespace_client_impl.expect("impl present per build_namespace_natively");
|
||||
crate::runtime::block_on(LanceNamespaceDatabase::connect(
|
||||
&ns_impl,
|
||||
ns_properties,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
namespace_client_pushdown_operations,
|
||||
))
|
||||
.infer_error()?
|
||||
} else {
|
||||
let namespace_client = extract_namespace_arc(py, namespace_client)?;
|
||||
LanceNamespaceDatabase::from_namespace_client(
|
||||
namespace_client,
|
||||
namespace_client_impl.unwrap_or_else(|| "python".to_string()),
|
||||
ns_properties,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
namespace_client_pushdown_operations,
|
||||
)
|
||||
};
|
||||
let database = LanceNamespaceDatabase::from_namespace_client(
|
||||
namespace_client,
|
||||
ns_impl,
|
||||
ns_properties,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
session,
|
||||
namespace_client_pushdown_operations,
|
||||
);
|
||||
|
||||
Ok(Connection::new(LanceConnection::new(
|
||||
Arc::new(database),
|
||||
@@ -1046,16 +635,6 @@ pub fn connect_namespace_client(
|
||||
)))
|
||||
}
|
||||
|
||||
/// Whether to build the namespace natively (from impl + properties) instead of
|
||||
/// wrapping a pre-built client. Native construction is required for the
|
||||
/// read-freshness provider to be installed
|
||||
fn build_namespace_natively(
|
||||
namespace_client_impl: Option<&str>,
|
||||
namespace_client_properties: &HashMap<String, String>,
|
||||
) -> bool {
|
||||
matches!(namespace_client_impl, Some("rest")) && !namespace_client_properties.is_empty()
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyClientConfig {
|
||||
user_agent: String,
|
||||
@@ -1154,36 +733,3 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn props(pairs: &[(&str, &str)]) -> HashMap<String, String> {
|
||||
pairs
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_build_only_for_rest_with_properties() {
|
||||
let rest = props(&[("uri", "http://localhost:10024")]);
|
||||
|
||||
// rest + non-empty properties -> build natively (installs the
|
||||
// read-freshness provider so checkout_latest() busts the server cache).
|
||||
assert!(build_namespace_natively(Some("rest"), &rest));
|
||||
|
||||
// dir is local (no server cache) -> wrap the pre-built client unchanged.
|
||||
assert!(!build_namespace_natively(
|
||||
Some("dir"),
|
||||
&props(&[("root", "/tmp")])
|
||||
));
|
||||
|
||||
// No impl: only a pre-built client was handed in -> wrap it as-is.
|
||||
assert!(!build_namespace_natively(None, &rest));
|
||||
|
||||
// rest but no properties: nothing to build a connection from -> wrap.
|
||||
assert!(!build_namespace_natively(Some("rest"), &HashMap::new()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,53 +319,11 @@ pub struct IndexConfig {
|
||||
|
||||
#[pymethods]
|
||||
impl IndexConfig {
|
||||
pub fn __repr__(&self, py: Python<'_>) -> String {
|
||||
let mut fields = vec![
|
||||
format!("name={:?}", self.name),
|
||||
format!("index_type={:?}", self.index_type),
|
||||
format!("columns={:?}", self.columns),
|
||||
];
|
||||
if let Some(v) = &self.index_uuid {
|
||||
fields.push(format!("index_uuid={:?}", v));
|
||||
}
|
||||
if let Some(v) = &self.type_url {
|
||||
fields.push(format!("type_url={:?}", v));
|
||||
}
|
||||
if let Some(v) = self.created_at {
|
||||
// Render the datetime's own Python repr so the value round-trips,
|
||||
// falling back to RFC 3339 if the conversion ever fails.
|
||||
let rendered = v
|
||||
.into_pyobject(py)
|
||||
.ok()
|
||||
.and_then(|obj| obj.into_any().repr().ok())
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_else(|| v.to_rfc3339());
|
||||
fields.push(format!("created_at={}", rendered));
|
||||
}
|
||||
if let Some(v) = self.num_indexed_rows {
|
||||
fields.push(format!("num_indexed_rows={}", fmt_thousands(v)));
|
||||
}
|
||||
if let Some(v) = self.num_unindexed_rows {
|
||||
fields.push(format!("num_unindexed_rows={}", fmt_thousands(v)));
|
||||
}
|
||||
if let Some(v) = self.size_bytes {
|
||||
fields.push(format!("size_bytes={}", fmt_thousands(v)));
|
||||
}
|
||||
if let Some(v) = self.num_segments {
|
||||
fields.push(format!("num_segments={}", v));
|
||||
}
|
||||
if let Some(v) = self.index_version {
|
||||
fields.push(format!("index_version={}", v));
|
||||
}
|
||||
if let Some(v) = &self.index_details {
|
||||
let details = v
|
||||
.bind(py)
|
||||
.repr()
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_else(|_| "<unavailable>".to_string());
|
||||
fields.push(format!("index_details={}", details));
|
||||
}
|
||||
format!("IndexConfig({})", fields.join(", "))
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!(
|
||||
"Index({}, columns={:?}, name=\"{}\")",
|
||||
self.index_type, self.columns, self.name
|
||||
)
|
||||
}
|
||||
|
||||
// For backwards-compatibility with the old sync SDK, we also support getting
|
||||
@@ -394,23 +352,6 @@ impl IndexConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Format an integer with `_` thousands separators, e.g. `24_500_213`.
|
||||
///
|
||||
/// Underscores are valid Python int-literal syntax, so the repr stays
|
||||
/// copy-pasteable and machine-parseable while remaining readable.
|
||||
fn fmt_thousands(n: u64) -> String {
|
||||
let digits = n.to_string();
|
||||
let bytes = digits.as_bytes();
|
||||
let mut out = String::with_capacity(digits.len() + digits.len() / 3);
|
||||
for (i, b) in bytes.iter().enumerate() {
|
||||
if i > 0 && (bytes.len() - i).is_multiple_of(3) {
|
||||
out.push('_');
|
||||
}
|
||||
out.push(*b as char);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn parse_index_details(py: Python<'_>, s: String) -> Py<PyAny> {
|
||||
let json = py.import("json").expect("json module is always available");
|
||||
match json.call_method1("loads", (s.as_str(),)) {
|
||||
|
||||
@@ -26,7 +26,6 @@ pub mod expr;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod namespace;
|
||||
pub mod oauth;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod runtime;
|
||||
@@ -41,11 +40,6 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
.write_style("LANCEDB_LOG_STYLE");
|
||||
env_logger::init_from_env(env);
|
||||
m.add_class::<Connection>()?;
|
||||
m.add_class::<connection::FunctionInfo>()?;
|
||||
m.add_class::<connection::MaterializedViewInfo>()?;
|
||||
m.add_class::<connection::JobInfo>()?;
|
||||
m.add_class::<connection::JobHistoryEntry>()?;
|
||||
m.add_class::<connection::JobErrorEntry>()?;
|
||||
m.add_class::<Session>()?;
|
||||
m.add_class::<Table>()?;
|
||||
m.add_class::<IndexConfig>()?;
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use pyo3::FromPyObject;
|
||||
|
||||
use lancedb::error::Error;
|
||||
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
|
||||
|
||||
/// Python-side OAuth configuration, extracted via FromPyObject.
|
||||
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyOAuthConfig {
|
||||
pub issuer_url: String,
|
||||
pub client_id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub flow: String,
|
||||
pub client_secret: Option<String>,
|
||||
pub managed_identity_client_id: Option<String>,
|
||||
pub refresh_buffer_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl TryFrom<PyOAuthConfig> for OAuthConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(py: PyOAuthConfig) -> Result<Self, Self::Error> {
|
||||
let flow = match py.flow.as_str() {
|
||||
"client_credentials" => OAuthFlow::ClientCredentials,
|
||||
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
|
||||
client_id: py.managed_identity_client_id,
|
||||
},
|
||||
other => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!("Unknown OAuth flow type: {other}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
issuer_url: py.issuer_url,
|
||||
client_id: py.client_id,
|
||||
client_secret: py.client_secret,
|
||||
scopes: py.scopes,
|
||||
flow,
|
||||
refresh_buffer_secs: py.refresh_buffer_secs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_unknown_oauth_flow_returns_invalid_input() {
|
||||
let config = PyOAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: "typo".to_string(),
|
||||
client_secret: None,
|
||||
managed_identity_client_id: None,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let err = OAuthConfig::try_from(config).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::InvalidInput { message }
|
||||
if message == "Unknown OAuth flow type: typo"
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -56,15 +56,6 @@ fn get_runtime() -> &'static runtime::Runtime {
|
||||
unsafe { &*new_ptr }
|
||||
}
|
||||
|
||||
/// Block the current thread on a future using the shared runtime.
|
||||
///
|
||||
/// For sync `#[pyfunction]`s that need to drive an async operation (e.g.
|
||||
/// building a namespace client). Must not be called from within the runtime's
|
||||
/// own worker threads.
|
||||
pub fn block_on<F: std::future::Future>(fut: F) -> F::Output {
|
||||
get_runtime().block_on(fut)
|
||||
}
|
||||
|
||||
/// Runs in async-signal context after `fork()` in the child. We can only
|
||||
/// touch atomics here; we deliberately leak the previous runtime because
|
||||
/// dropping a tokio `Runtime` would try to join its (now-dead) worker
|
||||
|
||||
@@ -17,8 +17,8 @@ use arrow::{
|
||||
pyarrow::{FromPyArrow, PyArrowType, ToPyArrow},
|
||||
};
|
||||
use lancedb::table::{
|
||||
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, LoadColumnsRequest,
|
||||
NewColumnTransform, OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
|
||||
AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform,
|
||||
OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
|
||||
};
|
||||
use pyo3::{
|
||||
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
@@ -1060,83 +1060,6 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_computed_columns(
|
||||
self_: PyRef<'_, Self>,
|
||||
columns: Vec<(String, String)>,
|
||||
expression: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.add_computed_columns(&columns, &expression)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (columns, where_clause=None, num_workers=None, max_workers=None, batch_size=None, priority=None))]
|
||||
pub fn refresh_column(
|
||||
self_: PyRef<'_, Self>,
|
||||
columns: Vec<String>,
|
||||
where_clause: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.refresh_column(
|
||||
&columns,
|
||||
where_clause,
|
||||
num_workers,
|
||||
max_workers,
|
||||
batch_size,
|
||||
priority,
|
||||
)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[pyo3(signature = (source_uris, source_format, target_key, columns, source_key=None, source_storage_options=None, on_missing=None, num_workers=None, max_workers=None, batch_size=None, commit_granularity=None, priority=None))]
|
||||
pub fn load_columns(
|
||||
self_: PyRef<'_, Self>,
|
||||
source_uris: Vec<String>,
|
||||
source_format: String,
|
||||
target_key: String,
|
||||
columns: Vec<(String, Option<String>)>,
|
||||
source_key: Option<String>,
|
||||
source_storage_options: Option<std::collections::HashMap<String, String>>,
|
||||
on_missing: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
commit_granularity: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
let request = LoadColumnsRequest {
|
||||
source_uris,
|
||||
source_format,
|
||||
source_storage_options,
|
||||
target_key,
|
||||
source_key,
|
||||
columns,
|
||||
on_missing,
|
||||
num_workers,
|
||||
max_workers,
|
||||
batch_size,
|
||||
commit_granularity,
|
||||
priority,
|
||||
};
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.load_columns(request).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_columns(
|
||||
self_: PyRef<'_, Self>,
|
||||
definitions: Vec<(String, String)>,
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_oauth_module():
|
||||
oauth_path = (
|
||||
Path(__file__).parents[1] / "python" / "lancedb" / "remote" / "oauth.py"
|
||||
)
|
||||
spec = importlib.util.spec_from_file_location("lancedb_remote_oauth", oauth_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_oauth_config_repr_redacts_client_secret():
|
||||
oauth = _load_oauth_module()
|
||||
|
||||
config = oauth.OAuthConfig(
|
||||
issuer_url="https://issuer.example.com",
|
||||
client_id="client-id",
|
||||
scopes=["scope"],
|
||||
client_secret="super-secret",
|
||||
)
|
||||
|
||||
rendered = repr(config)
|
||||
assert "super-secret" not in rendered
|
||||
assert "client_secret" not in rendered
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.31.0-beta.4"
|
||||
version = "0.30.1-beta.2"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -50,7 +50,7 @@ lance-namespace = { workspace = true }
|
||||
lance-namespace-impls = { workspace = true }
|
||||
moka = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread", "sync"] }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
log.workspace = true
|
||||
async-trait = "0"
|
||||
bytes = "1"
|
||||
@@ -75,7 +75,6 @@ reqwest = { version = "0.12.0", default-features = false, features = [
|
||||
"stream",
|
||||
], optional = true }
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
urlencoding = { version = "2", optional = true }
|
||||
uuid = { version = "1.7.0", features = ["v4", "v5"] }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
@@ -94,7 +93,6 @@ semver = { workspace = true }
|
||||
anyhow = "1"
|
||||
tempfile = "3.5.0"
|
||||
random_word = { version = "0.4.3", features = ["en"] }
|
||||
tokio = { version = "1.23", features = ["io-util", "macros", "net", "rt-multi-thread", "sync"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
walkdir = "2"
|
||||
aws-sdk-dynamodb = { version = "1.55.0" }
|
||||
@@ -131,13 +129,7 @@ huggingface = [
|
||||
"lance-namespace-impls/dir-huggingface",
|
||||
]
|
||||
dynamodb = ["lance/dynamodb", "aws"]
|
||||
remote = [
|
||||
"dep:reqwest",
|
||||
"dep:http",
|
||||
"dep:urlencoding",
|
||||
"lance-namespace-impls/rest",
|
||||
"lance-namespace-impls/rest-adapter",
|
||||
]
|
||||
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
||||
|
||||
@@ -1,435 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Lance blob v2 columns store large binary payloads out of line.
|
||||
//!
|
||||
//! Declare a column with [`blob`]. On write, [`crate::table::Table::add`] coerces
|
||||
//! raw `Binary` / `LargeBinary` into the blob struct layout. Queries return
|
||||
//! small descriptors, not bytes.
|
||||
//!
|
||||
//! Blob tables require Lance file format >= 2.2 and stable row ids at create.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::builder::LargeBinaryBuilder;
|
||||
use arrow_array::{Array, LargeBinaryArray, RecordBatch, StructArray, UInt8Array, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance::dataset::{Dataset, WriteParams};
|
||||
use lance_arrow::FieldExt;
|
||||
use lance_core::datatypes::parse_field_path;
|
||||
use lance_encoding::version::LanceFileVersion;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
pub use lance::dataset::BlobFile;
|
||||
|
||||
/// Creates an Arrow field for a Lance blob v2 column.
|
||||
///
|
||||
/// `Struct<data, uri>` with the `lance.blob.v2` marker. Same layout Lance
|
||||
/// expects on write.
|
||||
///
|
||||
/// A blob column may be top-level or nested inside a struct or list. Nested
|
||||
/// blobs are addressed by a dotted path (e.g. `info.blob`) in the read APIs.
|
||||
///
|
||||
/// ```
|
||||
/// use arrow_schema::{DataType, Field, Schema};
|
||||
///
|
||||
/// let schema = Schema::new(vec![
|
||||
/// Field::new("id", DataType::Int64, false),
|
||||
/// lancedb::blob("image", true),
|
||||
/// ]);
|
||||
/// ```
|
||||
pub fn blob(name: impl AsRef<str>, nullable: bool) -> Field {
|
||||
lance::blob::blob_field(name.as_ref(), nullable)
|
||||
}
|
||||
|
||||
/// Returns true if `field` is a blob v2 column.
|
||||
///
|
||||
/// ```
|
||||
/// let field = lancedb::blob("image", true);
|
||||
/// assert!(lancedb::blob::is_blob(&field));
|
||||
/// ```
|
||||
pub fn is_blob(field: &Field) -> bool {
|
||||
field.is_blob_v2()
|
||||
}
|
||||
|
||||
/// Returns true if `field`, or any field nested under it, is a blob v2 column.
|
||||
fn field_tree_has_blob_v2(field: &Field) -> bool {
|
||||
if field.is_blob_v2() {
|
||||
return true;
|
||||
}
|
||||
match field.data_type() {
|
||||
DataType::Struct(children) => children.iter().any(|c| field_tree_has_blob_v2(c)),
|
||||
DataType::List(child) | DataType::LargeList(child) | DataType::FixedSizeList(child, _) => {
|
||||
field_tree_has_blob_v2(child)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Collects the dotted paths of blob v2 columns under `field`, into `paths`.
|
||||
fn collect_blob_paths(field: &Field, prefix: &str, paths: &mut Vec<String>) {
|
||||
let path = if prefix.is_empty() {
|
||||
field.name().clone()
|
||||
} else {
|
||||
format!("{prefix}.{}", field.name())
|
||||
};
|
||||
if field.is_blob_v2() {
|
||||
paths.push(path);
|
||||
return;
|
||||
}
|
||||
match field.data_type() {
|
||||
DataType::Struct(children) => {
|
||||
for child in children {
|
||||
collect_blob_paths(child, &path, paths);
|
||||
}
|
||||
}
|
||||
DataType::List(child) | DataType::LargeList(child) | DataType::FixedSizeList(child, _) => {
|
||||
collect_blob_paths(child, &path, paths)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if `schema` declares any blob v2 column, including nested ones.
|
||||
pub(crate) fn has_blob_columns(schema: &Schema) -> bool {
|
||||
schema.fields().iter().any(|f| field_tree_has_blob_v2(f))
|
||||
}
|
||||
|
||||
/// Blob v2 column paths in `schema`, declaration order preserved. Nested blobs
|
||||
/// are dotted paths (e.g. `info.blob`).
|
||||
pub(crate) fn blob_column_names(schema: &Schema) -> Vec<String> {
|
||||
let mut paths = Vec::new();
|
||||
for field in schema.fields() {
|
||||
collect_blob_paths(field, "", &mut paths);
|
||||
}
|
||||
paths
|
||||
}
|
||||
|
||||
/// Bumps storage format to at least [`LanceFileVersion::V2_2`] for blob schemas.
|
||||
pub(crate) fn ensure_blob_storage_version(schema: &Schema, params: &mut WriteParams) {
|
||||
if !has_blob_columns(schema) {
|
||||
return;
|
||||
}
|
||||
|
||||
let resolved = params
|
||||
.data_storage_version
|
||||
.unwrap_or(LanceFileVersion::Stable)
|
||||
.resolve();
|
||||
if resolved < LanceFileVersion::V2_2 {
|
||||
params.data_storage_version = Some(LanceFileVersion::V2_2);
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that `column` exists and is a blob v2 column.
|
||||
///
|
||||
/// Legacy v1 columns (`lance-encoding:blob`) error with a migration hint.
|
||||
pub(crate) fn ensure_blob_v2_column(
|
||||
schema: &lance_core::datatypes::Schema,
|
||||
column: &str,
|
||||
) -> Result<()> {
|
||||
match schema.field(column) {
|
||||
Some(field) if field.is_blob_v2() => Ok(()),
|
||||
Some(field) if field.is_blob() => Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"column '{column}' is a legacy blob column; blob APIs require blob v2 columns \
|
||||
(ARROW:extension:name = \"lance.blob.v2\")"
|
||||
),
|
||||
}),
|
||||
Some(_) => Err(Error::InvalidInput {
|
||||
message: format!("column '{column}' is not a blob column"),
|
||||
}),
|
||||
None => Err(Error::InvalidInput {
|
||||
message: format!("no column named '{column}' in this table"),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the leaf descriptor `StructArray` for `column` in a descriptor batch.
|
||||
fn leaf_descriptor_struct<'a>(batch: &'a RecordBatch, column: &str) -> Result<&'a StructArray> {
|
||||
let path = parse_field_path(column).map_err(|e| Error::InvalidInput {
|
||||
message: format!("invalid blob column path '{column}': {e}"),
|
||||
})?;
|
||||
let not_struct = || Error::Runtime {
|
||||
message: format!("blob column '{column}' did not read back as a descriptor struct"),
|
||||
};
|
||||
let mut current = batch
|
||||
.column_by_name(&path[0])
|
||||
.and_then(|c| c.as_any().downcast_ref::<StructArray>())
|
||||
.ok_or_else(not_struct)?;
|
||||
for segment in &path[1..] {
|
||||
current = current
|
||||
.column_by_name(segment)
|
||||
.and_then(|c| c.as_any().downcast_ref::<StructArray>())
|
||||
.ok_or_else(not_struct)?;
|
||||
}
|
||||
Ok(current)
|
||||
}
|
||||
|
||||
/// Null rows in `row_ids`, from a descriptor take.
|
||||
///
|
||||
/// Lance `read_blobs` / `take_blobs` skip null rows (`kind == 0 && position == 0 && size == 0`).
|
||||
/// TODO(lance): aligned read API would drop this pass.
|
||||
async fn blob_null_mask(
|
||||
dataset: &Arc<Dataset>,
|
||||
column: &str,
|
||||
row_ids: &[u64],
|
||||
) -> Result<Vec<bool>> {
|
||||
let projection = dataset.schema().project(&[column])?;
|
||||
let descriptors = dataset.take_builder(row_ids, projection)?.execute().await?;
|
||||
if descriptors.num_rows() != row_ids.len() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"blob take for column '{column}' requested {} row ids but only {} exist in the \
|
||||
table; pass row ids collected from this table",
|
||||
row_ids.len(),
|
||||
descriptors.num_rows()
|
||||
),
|
||||
});
|
||||
}
|
||||
let descriptor_struct = leaf_descriptor_struct(&descriptors, column)?;
|
||||
let child = |name: &str| {
|
||||
descriptor_struct
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: format!("blob descriptor for '{column}' is missing the '{name}' field"),
|
||||
})
|
||||
};
|
||||
let kinds = child("kind")?
|
||||
.as_any()
|
||||
.downcast_ref::<UInt8Array>()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: format!("blob descriptor 'kind' for '{column}' is not a UInt8 array"),
|
||||
})?;
|
||||
let positions = child("position")?
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: format!("blob descriptor 'position' for '{column}' is not a UInt64 array"),
|
||||
})?;
|
||||
let sizes = child("size")?
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: format!("blob descriptor 'size' for '{column}' is not a UInt64 array"),
|
||||
})?;
|
||||
|
||||
// Match Lance `collect_blob_entries_v2` skip condition (`BlobKind::Inline` == 0).
|
||||
Ok((0..descriptor_struct.len())
|
||||
.map(|i| {
|
||||
descriptor_struct.is_null(i)
|
||||
|| kinds.is_null(i)
|
||||
|| (kinds.value(i) == 0 && positions.value(i) == 0 && sizes.value(i) == 0)
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn non_null_row_ids(row_ids: &[u64], null_mask: &[bool]) -> Vec<u64> {
|
||||
row_ids
|
||||
.iter()
|
||||
.zip(null_mask)
|
||||
.filter_map(|(row_id, is_null)| (!is_null).then_some(*row_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Materialize blob bytes for `row_ids` (same length and order, nulls preserved).
|
||||
pub(crate) async fn take_blobs_aligned(
|
||||
dataset: &Arc<Dataset>,
|
||||
column: &str,
|
||||
row_ids: &[u64],
|
||||
) -> Result<LargeBinaryArray> {
|
||||
ensure_blob_v2_column(dataset.schema(), column)?;
|
||||
if row_ids.is_empty() {
|
||||
return Ok(LargeBinaryBuilder::new().finish());
|
||||
}
|
||||
|
||||
let null_mask = blob_null_mask(dataset, column, row_ids).await?;
|
||||
let non_null_row_ids = non_null_row_ids(row_ids, &null_mask);
|
||||
let non_null_count = non_null_row_ids.len();
|
||||
let payloads = if non_null_count == 0 {
|
||||
Vec::new()
|
||||
} else {
|
||||
dataset
|
||||
.read_blobs(column)?
|
||||
.with_row_ids(non_null_row_ids)
|
||||
.preserve_order(true)
|
||||
.execute()
|
||||
.await?
|
||||
};
|
||||
|
||||
if payloads.len() != non_null_count {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"blob read for column '{column}' returned {} payloads for {} non-null rows",
|
||||
payloads.len(),
|
||||
non_null_count
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let mut builder = LargeBinaryBuilder::new();
|
||||
let mut payload_idx = 0;
|
||||
for is_null in &null_mask {
|
||||
if *is_null {
|
||||
builder.append_null();
|
||||
} else {
|
||||
builder.append_value(payloads[payload_idx].data.as_ref());
|
||||
payload_idx += 1;
|
||||
}
|
||||
}
|
||||
Ok(builder.finish())
|
||||
}
|
||||
|
||||
/// Open lazy [`BlobFile`] handles for `row_ids` (same length and order, nulls as `None`).
|
||||
pub(crate) async fn take_blob_files_aligned(
|
||||
dataset: &Arc<Dataset>,
|
||||
column: &str,
|
||||
row_ids: &[u64],
|
||||
) -> Result<Vec<Option<BlobFile>>> {
|
||||
ensure_blob_v2_column(dataset.schema(), column)?;
|
||||
if row_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let null_mask = blob_null_mask(dataset, column, row_ids).await?;
|
||||
let non_null_row_ids = non_null_row_ids(row_ids, &null_mask);
|
||||
let handles = if non_null_row_ids.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
dataset.take_blobs(&non_null_row_ids, column).await?
|
||||
};
|
||||
if handles.len() != non_null_row_ids.len() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"blob take for column '{column}' returned {} handles for {} non-null rows",
|
||||
handles.len(),
|
||||
non_null_row_ids.len()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let mut handles = handles.into_iter();
|
||||
Ok(null_mask
|
||||
.iter()
|
||||
.map(|is_null| {
|
||||
if *is_null {
|
||||
None
|
||||
} else {
|
||||
Some(handles.next().unwrap())
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use arrow_schema::DataType;
|
||||
use lance_arrow::ARROW_EXT_NAME_KEY;
|
||||
|
||||
fn blob_schema() -> Schema {
|
||||
Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
blob("image", true),
|
||||
])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blob_field_carries_v2_extension_marker() {
|
||||
let field = blob("image", true);
|
||||
assert_eq!(
|
||||
field.metadata().get(ARROW_EXT_NAME_KEY).map(String::as_str),
|
||||
Some("lance.blob.v2")
|
||||
);
|
||||
assert!(matches!(field.data_type(), DataType::Struct(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_blob_columns_detects_blob_fields() {
|
||||
assert!(has_blob_columns(&blob_schema()));
|
||||
let plain = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
|
||||
assert!(!has_blob_columns(&plain));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn storage_version_bumps_to_v2_2() {
|
||||
let mut params = WriteParams::default();
|
||||
ensure_blob_storage_version(&blob_schema(), &mut params);
|
||||
assert_eq!(
|
||||
params.data_storage_version.unwrap().resolve(),
|
||||
LanceFileVersion::V2_2
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn storage_version_overrides_lower_explicit_version() {
|
||||
let mut params = WriteParams {
|
||||
data_storage_version: Some(LanceFileVersion::V2_0),
|
||||
..Default::default()
|
||||
};
|
||||
ensure_blob_storage_version(&blob_schema(), &mut params);
|
||||
assert_eq!(
|
||||
params.data_storage_version.unwrap().resolve(),
|
||||
LanceFileVersion::V2_2
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn storage_version_keeps_higher_explicit_version() {
|
||||
let mut params = WriteParams {
|
||||
data_storage_version: Some(LanceFileVersion::V2_3),
|
||||
..Default::default()
|
||||
};
|
||||
ensure_blob_storage_version(&blob_schema(), &mut params);
|
||||
assert_eq!(params.data_storage_version.unwrap(), LanceFileVersion::V2_3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn legacy_v1_blob_column_is_rejected_with_migration_hint() {
|
||||
let legacy = Field::new("image", DataType::LargeBinary, true).with_metadata(
|
||||
std::collections::HashMap::from([(
|
||||
"lance-encoding:blob".to_string(),
|
||||
"true".to_string(),
|
||||
)]),
|
||||
);
|
||||
let arrow_schema = Schema::new(vec![legacy]);
|
||||
let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
|
||||
|
||||
let err = ensure_blob_v2_column(&lance_schema, "image").unwrap_err();
|
||||
assert!(matches!(err, Error::InvalidInput { .. }));
|
||||
assert!(err.to_string().contains("legacy blob column"));
|
||||
assert!(err.to_string().contains("lance.blob.v2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_blob_and_unknown_columns_are_rejected_by_name() {
|
||||
let arrow_schema = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
|
||||
let lance_schema = lance_core::datatypes::Schema::try_from(&arrow_schema).unwrap();
|
||||
|
||||
let err = ensure_blob_v2_column(&lance_schema, "id").unwrap_err();
|
||||
assert!(err.to_string().contains("'id' is not a blob column"));
|
||||
|
||||
let err = ensure_blob_v2_column(&lance_schema, "missing").unwrap_err();
|
||||
assert!(err.to_string().contains("no column named 'missing'"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blob_column_names_includes_nested_path() {
|
||||
let blob_field = blob("blob", true);
|
||||
let info = Field::new(
|
||||
"info",
|
||||
DataType::Struct(vec![Field::new("name", DataType::Utf8, false), blob_field].into()),
|
||||
true,
|
||||
);
|
||||
let schema = Schema::new(vec![Field::new("id", DataType::Int64, false), info]);
|
||||
assert_eq!(blob_column_names(&schema), vec!["info.blob"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn storage_version_noop_without_blob_columns() {
|
||||
let schema = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
|
||||
let mut params = WriteParams::default();
|
||||
ensure_blob_storage_version(&schema, &mut params);
|
||||
assert!(params.data_storage_version.is_none());
|
||||
}
|
||||
}
|
||||
@@ -23,10 +23,8 @@ use crate::connection::create_table::CreateTableBuilder;
|
||||
use crate::data::scannable::Scannable;
|
||||
use crate::database::listing::ListingDatabase;
|
||||
use crate::database::{
|
||||
CloneTableRequest, CreateFunctionRequest, CreateMaterializedViewRequest, Database,
|
||||
DatabaseOptions, FunctionInfo, JobErrorInfo, JobHistoryInfo, JobInfo, MaterializedViewInfo,
|
||||
MvRefreshPlan, OpenTableRequest, ReadConsistency, RefreshMaterializedViewRequest,
|
||||
TableLineageRequest, TableNamesRequest,
|
||||
CloneTableRequest, Database, DatabaseOptions, OpenTableRequest, ReadConsistency,
|
||||
TableNamesRequest,
|
||||
};
|
||||
use crate::embeddings::{EmbeddingRegistry, MemoryRegistry};
|
||||
use crate::error::{Error, Result};
|
||||
@@ -490,113 +488,6 @@ impl Connection {
|
||||
)
|
||||
}
|
||||
|
||||
// -- Derived compute: functions, materialized views, jobs -------------
|
||||
// Server-backed features (LanceDB Enterprise / Cloud); local
|
||||
// databases return NotSupported for now.
|
||||
|
||||
/// Register a UDF (CREATE FUNCTION).
|
||||
pub async fn create_function(&self, request: CreateFunctionRequest) -> Result<()> {
|
||||
self.internal.create_function(request).await
|
||||
}
|
||||
|
||||
/// List registered functions (SHOW FUNCTIONS).
|
||||
pub async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
|
||||
self.internal.list_functions().await
|
||||
}
|
||||
|
||||
/// Drop a registered function (DROP FUNCTION).
|
||||
pub async fn drop_function(&self, name: &str) -> Result<()> {
|
||||
self.internal.drop_function(name).await
|
||||
}
|
||||
|
||||
/// Create a materialized view (CREATE MATERIALIZED VIEW). Returns
|
||||
/// the initial-population job id, absent when `with_no_data`.
|
||||
pub async fn create_materialized_view(
|
||||
&self,
|
||||
request: CreateMaterializedViewRequest,
|
||||
) -> Result<Option<String>> {
|
||||
self.internal.create_materialized_view(request).await
|
||||
}
|
||||
|
||||
/// Refresh a materialized view; returns the refresh job id.
|
||||
pub async fn refresh_materialized_view(
|
||||
&self,
|
||||
request: RefreshMaterializedViewRequest,
|
||||
) -> Result<String> {
|
||||
self.internal.refresh_materialized_view(request).await
|
||||
}
|
||||
|
||||
/// Derived-compute lineage of a table/view (or column), as server-defined
|
||||
/// JSON. Read-only.
|
||||
pub async fn table_lineage(&self, request: TableLineageRequest) -> Result<String> {
|
||||
self.internal.table_lineage(request).await
|
||||
}
|
||||
|
||||
/// Plan a materialized-view refresh without submitting work
|
||||
/// (EXPLAIN REFRESH).
|
||||
pub async fn explain_refresh_materialized_view(
|
||||
&self,
|
||||
name: &str,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
) -> Result<MvRefreshPlan> {
|
||||
self.internal
|
||||
.explain_refresh_materialized_view(name, full, src_version)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Update a materialized view's options (ALTER MATERIALIZED VIEW).
|
||||
pub async fn alter_materialized_view(&self, name: &str, auto_refresh: bool) -> Result<()> {
|
||||
self.internal
|
||||
.alter_materialized_view(name, auto_refresh)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Drop a materialized view definition (DROP MATERIALIZED VIEW).
|
||||
pub async fn drop_materialized_view(&self, name: &str) -> Result<()> {
|
||||
self.internal.drop_materialized_view(name).await
|
||||
}
|
||||
|
||||
/// List registered materialized view definitions.
|
||||
pub async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
|
||||
self.internal.list_materialized_views().await
|
||||
}
|
||||
|
||||
/// List inflight server-side jobs across the database's tables.
|
||||
pub async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
|
||||
self.internal.list_jobs().await
|
||||
}
|
||||
|
||||
/// Cancel an inflight server-side job by id. Returns true if a
|
||||
/// matching inflight job was flagged for cancellation.
|
||||
pub async fn cancel_job(&self, job_id: &str) -> Result<bool> {
|
||||
self.internal.cancel_job(job_id).await
|
||||
}
|
||||
|
||||
/// Look up a single server-side job by id -- the `wait()`/status poll path.
|
||||
/// `table_hint` (the job's table) enables an O(1) server-side lookup; `None`
|
||||
/// scans the database's active jobs. A `None` result means unknown / not
|
||||
/// active.
|
||||
pub async fn get_job(&self, job_id: &str, table_hint: Option<&str>) -> Result<Option<JobInfo>> {
|
||||
self.internal.get_job(job_id, table_hint).await
|
||||
}
|
||||
|
||||
/// Durable job history (SHOW JOB HISTORY) across the database's tables.
|
||||
/// Pass `job_id` to narrow to a single job.
|
||||
pub async fn job_history(&self, job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
|
||||
self.internal.job_history(job_id).await
|
||||
}
|
||||
|
||||
/// Per-row UDF errors (SHOW ERRORS) across the database's tables, optionally
|
||||
/// filtered by `job_id` and/or `table`.
|
||||
pub async fn errors(
|
||||
&self,
|
||||
job_id: Option<&str>,
|
||||
table: Option<&str>,
|
||||
) -> Result<Vec<JobErrorInfo>> {
|
||||
self.internal.errors(job_id, table).await
|
||||
}
|
||||
|
||||
/// Rename a table in the database.
|
||||
///
|
||||
/// This is only supported in LanceDB Cloud.
|
||||
@@ -685,9 +576,6 @@ impl Connection {
|
||||
/// For LanceNamespaceDatabase, it is the underlying LanceNamespace.
|
||||
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
|
||||
/// For RemoteDatabase, it is the equivalent RestNamespace.
|
||||
///
|
||||
/// Remote connections using dynamic headers forward them through the
|
||||
/// namespace client's per-request context provider.
|
||||
pub async fn namespace_client(&self) -> Result<Arc<dyn lance_namespace::LanceNamespace>> {
|
||||
self.internal.namespace_client().await
|
||||
}
|
||||
@@ -696,9 +584,6 @@ impl Connection {
|
||||
/// Returns (impl_type, properties) where:
|
||||
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||
/// - properties: configuration properties for the namespace
|
||||
///
|
||||
/// Remote connections using dynamic headers cannot be exported because the
|
||||
/// namespace client config only carries static headers.
|
||||
pub async fn namespace_client_config(
|
||||
&self,
|
||||
) -> Result<(String, std::collections::HashMap<String, String>)> {
|
||||
@@ -776,8 +661,6 @@ pub struct ConnectRequest {
|
||||
pub struct ConnectBuilder {
|
||||
request: ConnectRequest,
|
||||
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: Option<crate::remote::OAuthConfig>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
@@ -799,8 +682,6 @@ impl ConnectBuilder {
|
||||
session: None,
|
||||
},
|
||||
embedding_registry: None,
|
||||
#[cfg(feature = "remote")]
|
||||
oauth_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -889,19 +770,6 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
|
||||
///
|
||||
/// This creates an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
|
||||
/// from the given config and sets it as the header provider. OAuth cannot
|
||||
/// be combined with an API key or another header provider.
|
||||
///
|
||||
/// Token acquisition and refresh are handled in Rust.
|
||||
#[cfg(feature = "remote")]
|
||||
pub fn oauth_config(mut self, config: crate::remote::OAuthConfig) -> Self {
|
||||
self.oauth_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||
self.embedding_registry = Some(registry);
|
||||
@@ -1047,40 +915,9 @@ impl ConnectBuilder {
|
||||
let region = options.region.ok_or_else(|| Error::InvalidInput {
|
||||
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
let api_key = match (&self.oauth_config, &options.api_key) {
|
||||
(Some(_), None) => String::new(),
|
||||
(Some(_), Some(_)) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"api_key and oauth_config cannot both be set when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
(None, Some(key)) => key.clone(),
|
||||
(None, None) => {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if self.oauth_config.is_some() && self.request.client_config.header_provider.is_some() {
|
||||
return Err(Error::InvalidInput {
|
||||
message:
|
||||
"oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut client_config = self.request.client_config;
|
||||
|
||||
if let Some(oauth_config) = self.oauth_config {
|
||||
let provider = crate::remote::OAuthHeaderProvider::new(oauth_config)?;
|
||||
client_config.header_provider =
|
||||
Some(Arc::new(provider) as Arc<dyn crate::remote::HeaderProvider>);
|
||||
}
|
||||
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
|
||||
let storage_options = StorageOptions(options.storage_options.clone());
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
@@ -1088,7 +925,7 @@ impl ConnectBuilder {
|
||||
&api_key,
|
||||
®ion,
|
||||
options.host_override,
|
||||
client_config,
|
||||
self.request.client_config,
|
||||
storage_options.into(),
|
||||
self.request.read_consistency_interval,
|
||||
)?);
|
||||
@@ -1397,83 +1234,6 @@ mod tests {
|
||||
assert_eq!(Some(&"EXPLICIT-VALUE".to_string()), options.get(opts_key));
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
#[tokio::test]
|
||||
async fn test_connect_rejects_api_key_with_oauth_config() {
|
||||
let oauth_config = crate::remote::OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: crate::remote::OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let result = ConnectBuilder::new("db://my-container/my-prefix")
|
||||
.region("us-east-1")
|
||||
.api_key("my-api-key")
|
||||
.oauth_config(oauth_config)
|
||||
.execute()
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidInput { message })
|
||||
if message
|
||||
== "api_key and oauth_config cannot both be set when connecting to LanceDb Cloud" =>
|
||||
{}
|
||||
Err(err) => panic!("expected InvalidInput, got {err:?}"),
|
||||
Ok(_) => panic!("expected api_key and oauth_config to be rejected"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
#[tokio::test]
|
||||
async fn test_connect_rejects_header_provider_with_oauth_config() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl crate::remote::HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let oauth_config = crate::remote::OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: crate::remote::OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
let client_config = crate::remote::ClientConfig {
|
||||
header_provider: Some(
|
||||
Arc::new(TestHeaderProvider) as Arc<dyn crate::remote::HeaderProvider>
|
||||
),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = ConnectBuilder::new("db://my-container/my-prefix")
|
||||
.region("us-east-1")
|
||||
.client_config(client_config)
|
||||
.oauth_config(oauth_config)
|
||||
.execute()
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidInput { message })
|
||||
if message
|
||||
== "oauth_config and client_config.header_provider cannot both be set when connecting to LanceDb Cloud" =>
|
||||
{}
|
||||
Err(err) => panic!("expected InvalidInput, got {err:?}"),
|
||||
Ok(_) => panic!("expected header_provider and oauth_config to be rejected"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn test_connect_relative() {
|
||||
|
||||
@@ -27,12 +27,11 @@ use lance_namespace::models::{
|
||||
};
|
||||
|
||||
use crate::data::scannable::Scannable;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::error::Result;
|
||||
use crate::table::{BaseTable, WriteOptions};
|
||||
|
||||
pub mod listing;
|
||||
pub mod namespace;
|
||||
pub(crate) mod read_freshness;
|
||||
|
||||
pub trait DatabaseOptions {
|
||||
fn serialize_into_map(&self, map: &mut HashMap<String, String>);
|
||||
@@ -200,205 +199,6 @@ pub enum ReadConsistency {
|
||||
Strong,
|
||||
}
|
||||
|
||||
/// A request to register a UDF (CREATE FUNCTION).
|
||||
///
|
||||
/// Functions are first-class database objects, decoupled from any
|
||||
/// column; computed columns and materialized views reference them by
|
||||
/// name. Server-backed feature (LanceDB Enterprise / Cloud).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CreateFunctionRequest {
|
||||
/// Function name.
|
||||
pub name: String,
|
||||
/// Implementation language (currently "python").
|
||||
pub language: String,
|
||||
/// SQL return type, e.g. `FLOAT`, `FLOAT[1536]`,
|
||||
/// `STRUCT(a FLOAT, b VARCHAR)`, `TABLE(chunk VARCHAR, idx INT)`.
|
||||
pub return_type: String,
|
||||
/// Function body: source text, or base64 cloudpickle bytes when
|
||||
/// `options["body_format"] = "cloudpickle"`.
|
||||
pub body: String,
|
||||
/// Options: input_columns, pip, num_gpus, batch_size, timeout,
|
||||
/// error_policy, docker_image, body_format, ...
|
||||
pub options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// A registered function, as returned by `list_functions`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FunctionInfo {
|
||||
pub name: String,
|
||||
pub language: String,
|
||||
pub return_type: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// A request to create a materialized view (CREATE MATERIALIZED VIEW).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CreateMaterializedViewRequest {
|
||||
/// View name.
|
||||
pub name: String,
|
||||
/// The view's SELECT statement, e.g.
|
||||
/// `SELECT id, embed(body) AS vec FROM articles WHERE id > 1`.
|
||||
/// Bare columns project through; function-call columns compute via
|
||||
/// registered UDFs (a RETURNS TABLE function makes a row-expanding
|
||||
/// chunker view).
|
||||
pub query: String,
|
||||
/// Refresh automatically when the source table changes.
|
||||
pub auto_refresh: bool,
|
||||
/// Register the definition only; skip the initial population.
|
||||
pub with_no_data: bool,
|
||||
/// Optional source column to partition the view's table function on. If the
|
||||
/// column has an IVF vector index the server partitions by its clusters
|
||||
/// (image-dedup style); otherwise it groups by distinct value.
|
||||
pub partition_by: Option<String>,
|
||||
}
|
||||
|
||||
impl CreateMaterializedViewRequest {
|
||||
pub fn new(name: impl Into<String>, query: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
query: query.into(),
|
||||
auto_refresh: false,
|
||||
with_no_data: false,
|
||||
partition_by: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A request to refresh a materialized view.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RefreshMaterializedViewRequest {
|
||||
/// View name.
|
||||
pub name: String,
|
||||
/// Force a full rebuild (recompute and replace every row) instead of the
|
||||
/// default incremental refresh.
|
||||
pub full: bool,
|
||||
/// Pin the refresh to a source-table version; latest when absent.
|
||||
pub src_version: Option<u64>,
|
||||
/// Initial worker count.
|
||||
pub num_workers: Option<u32>,
|
||||
/// Elastic worker ceiling.
|
||||
pub max_workers: Option<u32>,
|
||||
}
|
||||
|
||||
/// A request for the derived-compute lineage of a table/view (or one of its
|
||||
/// columns). The response is server-defined lineage JSON, returned opaque so
|
||||
/// this client need not model the server's lineage schema.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TableLineageRequest {
|
||||
/// Table or view name.
|
||||
pub name: String,
|
||||
/// Column for column-level lineage; whole table/view when absent.
|
||||
pub column: Option<String>,
|
||||
/// "upstream" | "downstream" | "both" (server default when absent).
|
||||
pub direction: Option<String>,
|
||||
/// Column-hops to walk; transitive when absent.
|
||||
pub depth: Option<u32>,
|
||||
}
|
||||
|
||||
impl RefreshMaterializedViewRequest {
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
full: false,
|
||||
src_version: None,
|
||||
num_workers: None,
|
||||
max_workers: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A registered materialized view definition, as returned by
|
||||
/// `list_materialized_views`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MaterializedViewInfo {
|
||||
pub name: String,
|
||||
pub source_table: String,
|
||||
/// Source columns projected through.
|
||||
pub projection: Vec<String>,
|
||||
/// `alias=expression` per UDF-computed column.
|
||||
pub udf_columns: Vec<String>,
|
||||
pub filter: Option<String>,
|
||||
pub auto_refresh: bool,
|
||||
}
|
||||
|
||||
/// A row from `list_jobs`: one inflight server-side job (index build,
|
||||
/// compaction, column refresh, view refresh, ...).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JobInfo {
|
||||
pub table: String,
|
||||
pub job_id: String,
|
||||
pub job_type: String,
|
||||
/// Lifecycle state: "running", "cancelling", or "stale".
|
||||
pub state: String,
|
||||
pub column: Option<String>,
|
||||
pub age_seconds: Option<i64>,
|
||||
pub command: Option<String>,
|
||||
pub units_done: Option<i64>,
|
||||
pub units_total: Option<i64>,
|
||||
/// Whether the job's final commit has completed (output visible).
|
||||
pub committed: bool,
|
||||
pub rows_skipped: u64,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// A row from `job_history`: one durable, completed/terminal server-side job
|
||||
/// record (SHOW JOB HISTORY), read from a table's `_job_history` store. Unlike
|
||||
/// `JobInfo` (live, inflight jobs) this carries created/updated/completed
|
||||
/// timestamps and the lifecycle event log.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JobHistoryInfo {
|
||||
pub table: String,
|
||||
pub job_id: String,
|
||||
pub job_type: String,
|
||||
pub state: String,
|
||||
pub column: Option<String>,
|
||||
pub created_ms: i64,
|
||||
pub updated_ms: i64,
|
||||
pub completed_ms: Option<i64>,
|
||||
pub rows_processed: Option<i64>,
|
||||
pub rows_skipped: Option<i64>,
|
||||
pub error: Option<String>,
|
||||
/// Newline-joined lifecycle event log, oldest first.
|
||||
pub events: Option<String>,
|
||||
}
|
||||
|
||||
/// A row from `errors`: one per-row UDF failure recorded by `error_policy=skip`
|
||||
/// (SHOW ERRORS).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JobErrorInfo {
|
||||
pub job_id: String,
|
||||
pub table: String,
|
||||
pub column: String,
|
||||
pub error_type: String,
|
||||
pub error_message: String,
|
||||
pub fragment_id: Option<i64>,
|
||||
pub source_row_id: Option<i64>,
|
||||
pub table_version: Option<i64>,
|
||||
pub age_seconds: Option<i64>,
|
||||
}
|
||||
|
||||
/// The plan a `REFRESH MATERIALIZED VIEW` would execute, as returned by
|
||||
/// `explain_refresh_materialized_view` (EXPLAIN REFRESH). No work is run.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MvRefreshPlan {
|
||||
pub table_name: String,
|
||||
/// Whether a refresh would do anything (rebuild or non-empty units).
|
||||
pub has_work: bool,
|
||||
pub source_version: u64,
|
||||
pub last_refreshed_version: Option<u64>,
|
||||
pub full_refresh: bool,
|
||||
/// Source changed non-append-only since the last refresh -> rebuild.
|
||||
pub rebuild: bool,
|
||||
/// Number of row-range work units the refresh would process.
|
||||
pub units_total: u64,
|
||||
}
|
||||
|
||||
fn not_supported<T>(what: &str) -> Result<T> {
|
||||
Err(Error::NotSupported {
|
||||
message: format!("{} is not supported by this database", what),
|
||||
})
|
||||
}
|
||||
|
||||
/// The `Database` trait defines the interface for database implementations.
|
||||
///
|
||||
/// A database is responsible for managing tables and their metadata.
|
||||
@@ -444,99 +244,6 @@ pub trait Database:
|
||||
///
|
||||
/// See [`CloneTableRequest`] for detailed documentation and examples.
|
||||
async fn clone_table(&self, request: CloneTableRequest) -> Result<Arc<dyn BaseTable>>;
|
||||
|
||||
// -- Derived compute: functions, materialized views, jobs -------------
|
||||
//
|
||||
// Server-backed features (LanceDB Enterprise / Cloud). The defaults
|
||||
// return NotSupported; the remote database overrides them. Local
|
||||
// single-node implementations are planned.
|
||||
|
||||
/// Register a UDF (CREATE FUNCTION).
|
||||
async fn create_function(&self, _request: CreateFunctionRequest) -> Result<()> {
|
||||
not_supported("create_function")
|
||||
}
|
||||
/// List registered functions (SHOW FUNCTIONS).
|
||||
async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
|
||||
not_supported("list_functions")
|
||||
}
|
||||
/// Drop a registered function (DROP FUNCTION).
|
||||
async fn drop_function(&self, _name: &str) -> Result<()> {
|
||||
not_supported("drop_function")
|
||||
}
|
||||
/// Create a materialized view (CREATE MATERIALIZED VIEW). Returns
|
||||
/// the initial-population job id, absent when `with_no_data`.
|
||||
async fn create_materialized_view(
|
||||
&self,
|
||||
_request: CreateMaterializedViewRequest,
|
||||
) -> Result<Option<String>> {
|
||||
not_supported("create_materialized_view")
|
||||
}
|
||||
/// Refresh a materialized view; returns the refresh job id.
|
||||
async fn refresh_materialized_view(
|
||||
&self,
|
||||
_request: RefreshMaterializedViewRequest,
|
||||
) -> Result<String> {
|
||||
not_supported("refresh_materialized_view")
|
||||
}
|
||||
/// Derived-compute lineage of a table/view (or column), as server-defined
|
||||
/// JSON. Read-only.
|
||||
async fn table_lineage(&self, _request: TableLineageRequest) -> Result<String> {
|
||||
not_supported("table_lineage")
|
||||
}
|
||||
/// Plan a materialized-view refresh without submitting work
|
||||
/// (EXPLAIN REFRESH). `full` plans a full rebuild (incremental
|
||||
/// planning requires stable row IDs on the source).
|
||||
async fn explain_refresh_materialized_view(
|
||||
&self,
|
||||
_name: &str,
|
||||
_full: bool,
|
||||
_src_version: Option<u64>,
|
||||
) -> Result<MvRefreshPlan> {
|
||||
not_supported("explain_refresh_materialized_view")
|
||||
}
|
||||
/// Update a materialized view's options (ALTER MATERIALIZED VIEW).
|
||||
async fn alter_materialized_view(&self, _name: &str, _auto_refresh: bool) -> Result<()> {
|
||||
not_supported("alter_materialized_view")
|
||||
}
|
||||
/// Drop a materialized view definition (DROP MATERIALIZED VIEW).
|
||||
async fn drop_materialized_view(&self, _name: &str) -> Result<()> {
|
||||
not_supported("drop_materialized_view")
|
||||
}
|
||||
/// List registered materialized view definitions.
|
||||
async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
|
||||
not_supported("list_materialized_views")
|
||||
}
|
||||
/// List inflight server-side jobs across the database's tables.
|
||||
async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
|
||||
not_supported("list_jobs")
|
||||
}
|
||||
/// Cancel an inflight server-side job by id. Returns true if a
|
||||
/// matching inflight job was found and flagged for cancellation,
|
||||
/// false if none was inflight (best-effort, like SQL `CANCEL JOB`).
|
||||
async fn cancel_job(&self, _job_id: &str) -> Result<bool> {
|
||||
not_supported("cancel_job")
|
||||
}
|
||||
/// Point-access for a single job by id -- the `wait()`/status poll path.
|
||||
/// `table_hint` (the job's table, which `wait()` callers know) enables an
|
||||
/// O(1) server-side lookup. `None` if the job is unknown or not active.
|
||||
async fn get_job(&self, _job_id: &str, _table_hint: Option<&str>) -> Result<Option<JobInfo>> {
|
||||
not_supported("get_job")
|
||||
}
|
||||
/// Durable job history (SHOW JOB HISTORY) across the database's tables,
|
||||
/// optionally narrowed to a single `job_id`.
|
||||
async fn job_history(&self, _job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
|
||||
not_supported("job_history")
|
||||
}
|
||||
/// Per-row UDF errors (SHOW ERRORS) recorded by `error_policy=skip` across
|
||||
/// the database's tables, optionally filtered by `job_id` and/or `table`.
|
||||
async fn errors(
|
||||
&self,
|
||||
_job_id: Option<&str>,
|
||||
_table: Option<&str>,
|
||||
) -> Result<Vec<JobErrorInfo>> {
|
||||
not_supported("errors")
|
||||
}
|
||||
|
||||
/// Open a table in the database
|
||||
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
|
||||
/// Rename a table in the database
|
||||
|
||||
@@ -18,7 +18,6 @@ use lance_table::io::commit::commit_handler_from_url;
|
||||
use object_store::local::LocalFileSystem;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::blob::{ensure_blob_storage_version, has_blob_columns};
|
||||
use crate::connection::ConnectRequest;
|
||||
use crate::database::ReadConsistency;
|
||||
use crate::database::namespace::LanceNamespaceDatabase;
|
||||
@@ -839,16 +838,13 @@ impl ListingDatabase {
|
||||
write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
|
||||
}
|
||||
|
||||
let data_schema = request.data.arrow_schema();
|
||||
if let Some(enable_stable_row_ids) = stable_row_ids_override
|
||||
.or(self.new_table_config.enable_stable_row_ids)
|
||||
.or(has_blob_columns(&data_schema).then_some(true))
|
||||
// Apply enable_stable_row_ids: table-level override takes precedence over connection config
|
||||
if let Some(enable_stable_row_ids) =
|
||||
stable_row_ids_override.or(self.new_table_config.enable_stable_row_ids)
|
||||
{
|
||||
write_params.enable_stable_row_ids = enable_stable_row_ids;
|
||||
}
|
||||
|
||||
ensure_blob_storage_version(&data_schema, &mut write_params);
|
||||
|
||||
if matches!(&request.mode, CreateTableMode::Overwrite) {
|
||||
write_params.mode = WriteMode::Overwrite;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! Namespace-based database implementation that delegates table management to lance-namespace
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lance::io::commit::namespace_manifest::LanceNamespaceExternalManifestStore;
|
||||
@@ -23,16 +23,12 @@ use lance_namespace_impls::ConnectBuilder;
|
||||
use lance_table::io::commit::CommitHandler;
|
||||
use lance_table::io::commit::external_manifest::ExternalManifestCommitHandler;
|
||||
|
||||
use crate::blob::{ensure_blob_storage_version, has_blob_columns};
|
||||
use crate::connection::NamespaceClientPushdownOperation;
|
||||
use crate::database::ReadConsistency;
|
||||
use crate::database::listing::{
|
||||
NewTableConfig, OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS, OPT_NEW_TABLE_STORAGE_VERSION,
|
||||
OPT_NEW_TABLE_V2_MANIFEST_PATHS,
|
||||
};
|
||||
use crate::database::read_freshness::{
|
||||
FreshnessBaselines, ReadFreshnessContextProvider, TableFreshness,
|
||||
};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::table::{NativeTable, map_namespace_lance_error};
|
||||
use lance::dataset::WriteMode;
|
||||
@@ -55,10 +51,6 @@ fn is_table_already_exists_namespace_error(err: &lance::Error) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Object-id delimiter default (matches `RestNamespaceBuilder`'s); overridable
|
||||
/// via the `delimiter` property.
|
||||
const DEFAULT_NAMESPACE_DELIMITER: &str = "$";
|
||||
|
||||
/// A database implementation that uses lance-namespace for table management
|
||||
pub struct LanceNamespaceDatabase {
|
||||
namespace: Arc<dyn LanceNamespace>,
|
||||
@@ -78,17 +70,6 @@ pub struct LanceNamespaceDatabase {
|
||||
ns_properties: HashMap<String, String>,
|
||||
// Options for tables created by this connection
|
||||
new_table_config: NewTableConfig,
|
||||
// Per-table read-freshness baselines, shared with the context provider.
|
||||
freshness_baselines: FreshnessBaselines,
|
||||
// Delimiter for building freshness keys; see `table_freshness`.
|
||||
delimiter: String,
|
||||
}
|
||||
|
||||
fn resolve_delimiter(ns_properties: &HashMap<String, String>) -> String {
|
||||
ns_properties
|
||||
.get("delimiter")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| DEFAULT_NAMESPACE_DELIMITER.to_string())
|
||||
}
|
||||
|
||||
impl LanceNamespaceDatabase {
|
||||
@@ -101,9 +82,6 @@ impl LanceNamespaceDatabase {
|
||||
session: Option<Arc<lance::session::Session>>,
|
||||
namespace_client_pushdown_operations: HashSet<NamespaceClientPushdownOperation>,
|
||||
) -> Self {
|
||||
// Client is pre-built, so we can't install the freshness provider here;
|
||||
// baselines are still tracked for a uniform bump path.
|
||||
let delimiter = resolve_delimiter(&namespace_client_properties);
|
||||
Self {
|
||||
namespace: namespace_client,
|
||||
storage_options,
|
||||
@@ -114,8 +92,6 @@ impl LanceNamespaceDatabase {
|
||||
ns_impl: namespace_client_impl,
|
||||
ns_properties: namespace_client_properties,
|
||||
new_table_config: NewTableConfig::default(),
|
||||
freshness_baselines: Arc::new(Mutex::new(HashMap::new())),
|
||||
delimiter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,19 +136,10 @@ impl LanceNamespaceDatabase {
|
||||
if let Some(ref sess) = session {
|
||||
builder = builder.session(sess.clone());
|
||||
}
|
||||
|
||||
// Install the read-freshness provider before building the client.
|
||||
let freshness_baselines: FreshnessBaselines = Arc::new(Mutex::new(HashMap::new()));
|
||||
builder = builder.context_provider(Arc::new(ReadFreshnessContextProvider::new(
|
||||
freshness_baselines.clone(),
|
||||
read_consistency_interval,
|
||||
)));
|
||||
|
||||
let namespace = builder.connect().await.map_err(|e| Error::InvalidInput {
|
||||
message: format!("Failed to connect to namespace: {:?}", e),
|
||||
})?;
|
||||
|
||||
let delimiter = resolve_delimiter(&ns_properties);
|
||||
Ok(Self {
|
||||
namespace,
|
||||
storage_options,
|
||||
@@ -183,20 +150,9 @@ impl LanceNamespaceDatabase {
|
||||
ns_impl: ns_impl.to_string(),
|
||||
ns_properties,
|
||||
new_table_config,
|
||||
freshness_baselines,
|
||||
delimiter,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a table's freshness handle, keyed to match the `object_id` the
|
||||
/// namespace client sends on reads (table-id parts joined by the delimiter).
|
||||
fn table_freshness(&self, namespace_path: &[String], name: &str) -> TableFreshness {
|
||||
let mut parts = namespace_path.to_vec();
|
||||
parts.push(name.to_string());
|
||||
let key = parts.join(&self.delimiter);
|
||||
TableFreshness::new(self.freshness_baselines.clone(), key)
|
||||
}
|
||||
|
||||
fn extract_storage_overrides(
|
||||
&self,
|
||||
request: &DbCreateTableRequest,
|
||||
@@ -258,16 +214,12 @@ impl LanceNamespaceDatabase {
|
||||
params.enable_v2_manifest_paths = enable_v2_manifest_paths;
|
||||
}
|
||||
|
||||
let data_schema = request.data.schema();
|
||||
if let Some(enable_stable_row_ids) = stable_row_ids_override
|
||||
.or(self.new_table_config.enable_stable_row_ids)
|
||||
.or(has_blob_columns(data_schema.as_ref()).then_some(true))
|
||||
if let Some(enable_stable_row_ids) =
|
||||
stable_row_ids_override.or(self.new_table_config.enable_stable_row_ids)
|
||||
{
|
||||
params.enable_stable_row_ids = enable_stable_row_ids;
|
||||
}
|
||||
|
||||
ensure_blob_storage_version(data_schema.as_ref(), params);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -379,8 +331,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
self.pushdown_operations.clone(),
|
||||
self.session.clone(),
|
||||
)
|
||||
.await?
|
||||
.with_freshness(self.table_freshness(&request.namespace_path, &request.name));
|
||||
.await?;
|
||||
|
||||
return Ok(Arc::new(native_table));
|
||||
}
|
||||
@@ -511,8 +462,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
self.pushdown_operations.clone(),
|
||||
self.session.clone(),
|
||||
)
|
||||
.await?
|
||||
.with_freshness(self.table_freshness(&request.namespace_path, &request.name));
|
||||
.await?;
|
||||
|
||||
Ok(Arc::new(native_table))
|
||||
}
|
||||
@@ -528,8 +478,7 @@ impl Database for LanceNamespaceDatabase {
|
||||
self.pushdown_operations.clone(),
|
||||
self.session.clone(),
|
||||
)
|
||||
.await?
|
||||
.with_freshness(self.table_freshness(&request.namespace_path, &request.name));
|
||||
.await?;
|
||||
|
||||
Ok(Arc::new(native_table))
|
||||
}
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Read-freshness signaling for the lance-namespace path.
|
||||
//!
|
||||
//! Against a server that serves cached table metadata up to some staleness
|
||||
//! window, a handle that just wrote (or asked for the latest version via
|
||||
//! `checkout_latest`) can still read a stale snapshot. To prevent that, reads
|
||||
//! routed through the namespace client carry an `x-lancedb-min-timestamp`
|
||||
//! header naming the oldest snapshot the caller will accept.
|
||||
//!
|
||||
//! This mirrors `remote::table`: a per-table baseline is bumped to "now" on
|
||||
//! every write and on `checkout_latest()`, and reads send
|
||||
//! `max(baseline, now - read_consistency_interval)`. Since the namespace client
|
||||
//! takes no headers directly, a [`DynamicContextProvider`] injects it per request.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||
|
||||
/// Provider context keys prefixed with `headers.` become HTTP headers (prefix
|
||||
/// stripped), so this emits the `x-lancedb-min-timestamp` header.
|
||||
const MIN_TIMESTAMP_CONTEXT_KEY: &str = "headers.x-lancedb-min-timestamp";
|
||||
|
||||
/// Per-table freshness baselines (keyed by namespace object id), shared between
|
||||
/// the provider that reads them and the table handles that bump them.
|
||||
pub type FreshnessBaselines = Arc<Mutex<HashMap<String, SystemTime>>>;
|
||||
|
||||
/// `max(baseline, now - interval)`, or `None` when neither constraint applies.
|
||||
fn compute_min_timestamp(
|
||||
baseline: Option<SystemTime>,
|
||||
interval: Option<Duration>,
|
||||
now: SystemTime,
|
||||
) -> Option<SystemTime> {
|
||||
let interval_based = match interval {
|
||||
None => None,
|
||||
Some(d) if d.is_zero() => Some(now),
|
||||
Some(d) => Some(now.checked_sub(d).unwrap_or(now)),
|
||||
};
|
||||
match (interval_based, baseline) {
|
||||
(None, None) => None,
|
||||
(Some(t), None) | (None, Some(t)) => Some(t),
|
||||
(Some(a), Some(b)) => Some(a.max(b)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Advance the baseline to `now`, never backwards, so a concurrent handle's
|
||||
/// write can't lower a floor another handle already set.
|
||||
fn next_freshness_baseline(prev: Option<SystemTime>, now: SystemTime) -> SystemTime {
|
||||
match prev {
|
||||
Some(p) => p.max(now),
|
||||
None => now,
|
||||
}
|
||||
}
|
||||
|
||||
/// A handle's view of the shared baseline map for a single table.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TableFreshness {
|
||||
baselines: FreshnessBaselines,
|
||||
/// Namespace object id for this table (matches the read's `object_id`).
|
||||
key: String,
|
||||
}
|
||||
|
||||
impl TableFreshness {
|
||||
pub fn new(baselines: FreshnessBaselines, key: String) -> Self {
|
||||
Self { baselines, key }
|
||||
}
|
||||
|
||||
pub fn bump(&self) {
|
||||
let now = SystemTime::now();
|
||||
let mut baselines = self.baselines.lock().unwrap();
|
||||
let prev = baselines.get(&self.key).copied();
|
||||
baselines.insert(self.key.clone(), next_freshness_baseline(prev, now));
|
||||
}
|
||||
}
|
||||
|
||||
/// Read ops that can be served stale and so carry the freshness floor.
|
||||
/// `list_table_versions` resolves "latest" for managed-versioning tables, so it
|
||||
/// is what makes `checkout_latest()` observe a prior write.
|
||||
fn is_read_operation(operation: &str) -> bool {
|
||||
matches!(
|
||||
operation,
|
||||
"describe_table" | "list_table_versions" | "query_table" | "list_tables"
|
||||
)
|
||||
}
|
||||
|
||||
/// Injects `x-lancedb-min-timestamp` on namespace reads, per addressed table.
|
||||
#[derive(Debug)]
|
||||
pub struct ReadFreshnessContextProvider {
|
||||
baselines: FreshnessBaselines,
|
||||
read_consistency_interval: Option<Duration>,
|
||||
}
|
||||
|
||||
impl ReadFreshnessContextProvider {
|
||||
pub fn new(baselines: FreshnessBaselines, read_consistency_interval: Option<Duration>) -> Self {
|
||||
Self {
|
||||
baselines,
|
||||
read_consistency_interval,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DynamicContextProvider for ReadFreshnessContextProvider {
|
||||
fn provide_context(&self, info: &OperationInfo) -> HashMap<String, String> {
|
||||
if !is_read_operation(&info.operation) {
|
||||
return HashMap::new();
|
||||
}
|
||||
|
||||
let baseline = self.baselines.lock().unwrap().get(&info.object_id).copied();
|
||||
match compute_min_timestamp(baseline, self.read_consistency_interval, SystemTime::now()) {
|
||||
Some(ts) => {
|
||||
let dt: chrono::DateTime<chrono::Utc> = ts.into();
|
||||
HashMap::from([(MIN_TIMESTAMP_CONTEXT_KEY.to_string(), dt.to_rfc3339())])
|
||||
}
|
||||
None => HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Allowed slop when comparing a header timestamp against a locally
|
||||
/// captured wall-clock bound. Tests run fast enough that 1s is plenty.
|
||||
const TOLERANCE: Duration = Duration::from_secs(1);
|
||||
|
||||
fn parse_header_ts(headers: &HashMap<String, String>) -> SystemTime {
|
||||
let value = headers
|
||||
.get(MIN_TIMESTAMP_CONTEXT_KEY)
|
||||
.expect("expected min-timestamp context key");
|
||||
chrono::DateTime::parse_from_rfc3339(value)
|
||||
.unwrap()
|
||||
.with_timezone(&chrono::Utc)
|
||||
.into()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_min_timestamp_combines_baseline_and_interval() {
|
||||
let now = SystemTime::now();
|
||||
let baseline = now - Duration::from_secs(60);
|
||||
|
||||
// No interval, no baseline -> no header.
|
||||
assert_eq!(compute_min_timestamp(None, None, now), None);
|
||||
|
||||
// Baseline only -> baseline.
|
||||
assert_eq!(
|
||||
compute_min_timestamp(Some(baseline), None, now),
|
||||
Some(baseline)
|
||||
);
|
||||
|
||||
// ZERO interval, no baseline -> now (strong consistency).
|
||||
assert_eq!(
|
||||
compute_min_timestamp(None, Some(Duration::ZERO), now),
|
||||
Some(now)
|
||||
);
|
||||
|
||||
// Positive interval, no baseline -> now - interval.
|
||||
assert_eq!(
|
||||
compute_min_timestamp(None, Some(Duration::from_secs(10)), now),
|
||||
Some(now - Duration::from_secs(10))
|
||||
);
|
||||
|
||||
// Both: pick the more-recent (tighter) constraint.
|
||||
// baseline = now-60, now-interval = now-10. now-10 is newer.
|
||||
assert_eq!(
|
||||
compute_min_timestamp(Some(baseline), Some(Duration::from_secs(10)), now),
|
||||
Some(now - Duration::from_secs(10))
|
||||
);
|
||||
|
||||
// Both, baseline newer: pick baseline.
|
||||
let recent_baseline = now - Duration::from_secs(5);
|
||||
assert_eq!(
|
||||
compute_min_timestamp(Some(recent_baseline), Some(Duration::from_secs(60)), now),
|
||||
Some(recent_baseline)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_freshness_baseline_is_monotonic() {
|
||||
let now = SystemTime::now();
|
||||
let earlier = now - Duration::from_secs(30);
|
||||
let later = now + Duration::from_secs(30);
|
||||
|
||||
// No prior baseline -> now.
|
||||
assert_eq!(next_freshness_baseline(None, now), now);
|
||||
// Prior baseline older than now -> now.
|
||||
assert_eq!(next_freshness_baseline(Some(earlier), now), now);
|
||||
// Prior baseline newer than now -> keep the newer baseline.
|
||||
assert_eq!(next_freshness_baseline(Some(later), now), later);
|
||||
}
|
||||
|
||||
fn provider_with(
|
||||
entries: &[(&str, SystemTime)],
|
||||
interval: Option<Duration>,
|
||||
) -> ReadFreshnessContextProvider {
|
||||
let map: HashMap<String, SystemTime> =
|
||||
entries.iter().map(|(k, v)| (k.to_string(), *v)).collect();
|
||||
ReadFreshnessContextProvider::new(Arc::new(Mutex::new(map)), interval)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_emits_header_at_or_after_bumped_baseline() {
|
||||
// A baseline set "now" with no interval: every read op must carry a
|
||||
// floor at or after that baseline. `list_table_versions` is the hook
|
||||
// that makes managed-versioning `checkout_latest()` observe a write.
|
||||
let baseline = SystemTime::now();
|
||||
let provider = provider_with(&[("ns$tbl", baseline)], None);
|
||||
|
||||
// These ops are keyed by the table id, so they pick up the per-table
|
||||
// baseline. (`list_tables` is keyed by the namespace, so it is covered
|
||||
// separately by the interval-floor test.)
|
||||
for op in ["describe_table", "list_table_versions", "query_table"] {
|
||||
let ctx = provider.provide_context(&OperationInfo::new(op, "ns$tbl"));
|
||||
let sent = parse_header_ts(&ctx);
|
||||
assert!(
|
||||
sent >= baseline - TOLERANCE && sent <= baseline + TOLERANCE,
|
||||
"operation {op} should carry a floor at the bumped baseline"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_list_tables_uses_interval_floor_not_table_baseline() {
|
||||
// `list_tables` is addressed by the namespace id, which never matches a
|
||||
// per-table baseline key, so a bumped table baseline must not leak onto
|
||||
// it. With no interval it sends nothing; with one it sends now-interval.
|
||||
let provider = provider_with(&[("ns$tbl", SystemTime::now())], None);
|
||||
let ctx = provider.provide_context(&OperationInfo::new("list_tables", "ns"));
|
||||
assert!(
|
||||
ctx.is_empty(),
|
||||
"list_tables must not inherit a per-table baseline"
|
||||
);
|
||||
|
||||
let interval = Duration::from_secs(30);
|
||||
let provider = provider_with(&[("ns$tbl", SystemTime::now())], Some(interval));
|
||||
let before = SystemTime::now();
|
||||
let ctx = provider.provide_context(&OperationInfo::new("list_tables", "ns"));
|
||||
let after = SystemTime::now();
|
||||
let sent = parse_header_ts(&ctx);
|
||||
assert!(
|
||||
sent >= before - interval - TOLERANCE && sent <= after - interval + TOLERANCE,
|
||||
"list_tables should carry the interval floor"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_no_header_for_empty_baseline_and_no_interval() {
|
||||
// Manual consistency (no interval) on a table that was never bumped:
|
||||
// no floor, so the server may serve from cache.
|
||||
let provider = provider_with(&[], None);
|
||||
let ctx = provider.provide_context(&OperationInfo::new("describe_table", "ns$tbl"));
|
||||
assert!(ctx.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_interval_floor_applies_without_baseline() {
|
||||
// With a consistency interval and no baseline, the floor is now-interval.
|
||||
let interval = Duration::from_secs(30);
|
||||
let provider = provider_with(&[], Some(interval));
|
||||
|
||||
let before = SystemTime::now();
|
||||
let ctx = provider.provide_context(&OperationInfo::new("query_table", "ns$tbl"));
|
||||
let after = SystemTime::now();
|
||||
|
||||
let sent = parse_header_ts(&ctx);
|
||||
assert!(
|
||||
sent >= before - interval - TOLERANCE && sent <= after - interval + TOLERANCE,
|
||||
"expected floor at roughly now - interval"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_non_read_ops_emit_nothing() {
|
||||
// Even with a fresh baseline and a zero interval, a non-read operation
|
||||
// (which establishes rather than consumes a baseline) sends no header.
|
||||
let provider = provider_with(&[("ns$tbl", SystemTime::now())], Some(Duration::ZERO));
|
||||
for op in [
|
||||
"create_table",
|
||||
"register_table",
|
||||
"drop_table",
|
||||
"rename_table",
|
||||
// Pinned to an immutable version, so it cannot be served stale.
|
||||
"describe_table_version",
|
||||
] {
|
||||
let ctx = provider.provide_context(&OperationInfo::new(op, "ns$tbl"));
|
||||
assert!(
|
||||
ctx.is_empty(),
|
||||
"operation {op} must not send a freshness header"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_uses_per_table_baseline() {
|
||||
// The floor is looked up by object id, so an unrelated table's baseline
|
||||
// does not leak onto another table's read.
|
||||
let baseline = SystemTime::now();
|
||||
let provider = provider_with(&[("ns$has_baseline", baseline)], None);
|
||||
|
||||
// The bumped table gets a header.
|
||||
let hit =
|
||||
provider.provide_context(&OperationInfo::new("describe_table", "ns$has_baseline"));
|
||||
assert!(!hit.is_empty());
|
||||
|
||||
// A different table with no baseline (and no interval) gets nothing.
|
||||
let miss = provider.provide_context(&OperationInfo::new("describe_table", "ns$other"));
|
||||
assert!(miss.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ use serde_json::{Value, json};
|
||||
use super::EmbeddingFunction;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use tokio::runtime::{Handle, RuntimeFlavor};
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::task::block_in_place;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -148,12 +148,6 @@ impl BedrockEmbeddingFunction {
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// Bedrock's SDK is async but this trait method is synchronous, so we
|
||||
// bridge with `block_in_place` + `block_on`. That requires a
|
||||
// multi-threaded Tokio runtime; return a typed error instead of
|
||||
// panicking when no compatible runtime is available.
|
||||
let handle = current_multi_thread_handle()?;
|
||||
|
||||
for text in texts {
|
||||
let request_body = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => {
|
||||
@@ -169,28 +163,24 @@ impl BedrockEmbeddingFunction {
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize before entering the blocking section so a serialization
|
||||
// failure surfaces as a typed error rather than an `unwrap` panic.
|
||||
let body = serde_json::to_vec(&request_body).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to serialize Bedrock request: {e}"),
|
||||
})?;
|
||||
|
||||
let client = self.client.clone();
|
||||
let model_id = self.model.model_id().to_string();
|
||||
let request_body = request_body.clone();
|
||||
|
||||
let response = block_in_place(|| {
|
||||
handle.block_on(async move {
|
||||
let response = block_in_place(move || {
|
||||
Handle::current().block_on(async move {
|
||||
client
|
||||
.invoke_model()
|
||||
.model_id(model_id)
|
||||
.body(aws_sdk_bedrockruntime::primitives::Blob::new(body))
|
||||
.body(aws_sdk_bedrockruntime::primitives::Blob::new(
|
||||
serde_json::to_vec(&request_body).unwrap(),
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Bedrock invoke_model request failed: {e}"),
|
||||
})
|
||||
.map_err(Box::new)
|
||||
})
|
||||
})?;
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let response_json: Value =
|
||||
serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime {
|
||||
@@ -198,12 +188,22 @@ impl BedrockEmbeddingFunction {
|
||||
})?;
|
||||
|
||||
let embedding = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => {
|
||||
json_array_to_f32(&response_json["embedding"], "embedding")?
|
||||
}
|
||||
BedrockEmbeddingModel::CohereLarge => {
|
||||
json_array_to_f32(&response_json["embeddings"][0], "embeddings")?
|
||||
}
|
||||
BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embedding in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embeddings in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
};
|
||||
|
||||
builder.append_slice(&embedding);
|
||||
@@ -212,86 +212,3 @@ impl BedrockEmbeddingFunction {
|
||||
Ok(builder.finish())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a handle to the current multi-threaded Tokio runtime, or a typed
|
||||
/// [`Error::Runtime`] when called outside a runtime or on the current-thread
|
||||
/// runtime. This keeps the synchronous-over-async bridge in
|
||||
/// [`BedrockEmbeddingFunction::compute_inner`] from panicking on runtime
|
||||
/// configurations that cannot support `block_in_place`.
|
||||
fn current_multi_thread_handle() -> Result<Handle> {
|
||||
let handle = Handle::try_current().map_err(|e| Error::Runtime {
|
||||
message: format!("Bedrock embedding must be called from within a Tokio runtime: {e}"),
|
||||
})?;
|
||||
if handle.runtime_flavor() == RuntimeFlavor::CurrentThread {
|
||||
return Err(Error::Runtime {
|
||||
message: "Bedrock embedding requires a multi-threaded Tokio runtime; the \
|
||||
current-thread runtime cannot use `block_in_place`"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
/// Converts a JSON value expected to be an array of numbers into `Vec<f32>`.
|
||||
///
|
||||
/// Returns a typed [`Error::Runtime`] (rather than panicking) when the value is
|
||||
/// not an array or contains a non-numeric element, so malformed provider
|
||||
/// responses degrade gracefully.
|
||||
fn json_array_to_f32(value: &Value, field: &str) -> Result<Vec<f32>> {
|
||||
let arr = value.as_array().ok_or_else(|| Error::Runtime {
|
||||
message: format!("Missing or non-array '{field}' field in Bedrock response"),
|
||||
})?;
|
||||
arr.iter()
|
||||
.map(|v| {
|
||||
v.as_f64().map(|f| f as f32).ok_or_else(|| Error::Runtime {
|
||||
message: format!("Non-numeric value in Bedrock '{field}' embedding: {v}"),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn json_array_to_f32_parses_numbers() {
|
||||
let v = json!([1.0, 2, -3.5]);
|
||||
let out = json_array_to_f32(&v, "embedding").unwrap();
|
||||
assert_eq!(out, vec![1.0_f32, 2.0, -3.5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_array_to_f32_rejects_non_array() {
|
||||
// Missing field indexes to `Value::Null`; a malformed payload should be
|
||||
// a typed error, not a panic.
|
||||
let v = json!({"unexpected": "shape"});
|
||||
let err = json_array_to_f32(&v["embedding"], "embedding").unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_array_to_f32_rejects_non_numeric_element() {
|
||||
let v = json!([1.0, "not-a-number", 3.0]);
|
||||
let err = json_array_to_f32(&v, "embedding").unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_errors_without_runtime() {
|
||||
// No Tokio runtime in scope -> typed error instead of a panic.
|
||||
let err = current_multi_thread_handle().unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn handle_errors_on_current_thread_runtime() {
|
||||
let err = current_multi_thread_handle().unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn handle_ok_on_multi_thread_runtime() {
|
||||
current_multi_thread_handle().expect("multi-threaded runtime should be accepted");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,7 +163,6 @@
|
||||
//! ```
|
||||
|
||||
pub mod arrow;
|
||||
pub mod blob;
|
||||
pub mod connection;
|
||||
pub mod data;
|
||||
pub mod database;
|
||||
@@ -189,7 +188,6 @@ use std::{fmt::Display, str::FromStr};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use blob::{blob, is_blob};
|
||||
pub use connection::{ConnectNamespaceBuilder, Connection};
|
||||
pub use error::{Error, Result};
|
||||
use lance_index::vector::ApproxMode as LanceApproxMode;
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod db;
|
||||
pub mod oauth;
|
||||
mod retry;
|
||||
pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
@@ -21,4 +20,3 @@ const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
||||
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
||||
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};
|
||||
|
||||
@@ -459,14 +459,12 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
config: &ClientConfig,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
if !api_key.is_empty() {
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-api-key"),
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
if region == "local" {
|
||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||
headers.insert(
|
||||
@@ -1007,33 +1005,6 @@ mod tests {
|
||||
assert!(!config_tls.assert_hostname);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_headers_skip_empty_api_key() {
|
||||
let headers = RestfulLanceDbClient::<Sender>::default_headers(
|
||||
"",
|
||||
"us-east-1",
|
||||
"db-name",
|
||||
false,
|
||||
&RemoteOptions::default(),
|
||||
None,
|
||||
&ClientConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
assert!(!headers.contains_key("x-api-key"));
|
||||
|
||||
let headers = RestfulLanceDbClient::<Sender>::default_headers(
|
||||
"api-key",
|
||||
"us-east-1",
|
||||
"db-name",
|
||||
false,
|
||||
&RemoteOptions::default(),
|
||||
None,
|
||||
&ClientConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(headers.get("x-api-key").unwrap(), "api-key");
|
||||
}
|
||||
|
||||
// Test implementation of HeaderProvider
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestHeaderProvider {
|
||||
|
||||
@@ -7,7 +7,6 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use lance_io::object_store::StorageOptions;
|
||||
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||
use moka::future::Cache;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
|
||||
@@ -19,264 +18,18 @@ use lance_namespace::models::{
|
||||
|
||||
use crate::Error;
|
||||
use crate::database::{
|
||||
CloneTableRequest, CreateFunctionRequest, CreateMaterializedViewRequest, CreateTableMode,
|
||||
CreateTableRequest, Database, DatabaseOptions, FunctionInfo, JobErrorInfo, JobHistoryInfo,
|
||||
JobInfo, MaterializedViewInfo, MvRefreshPlan, OpenTableRequest, ReadConsistency,
|
||||
RefreshMaterializedViewRequest, TableLineageRequest, TableNamesRequest,
|
||||
CloneTableRequest, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
|
||||
OpenTableRequest, ReadConsistency, TableNamesRequest,
|
||||
};
|
||||
use crate::error::Result;
|
||||
use crate::remote::util::stream_as_body;
|
||||
use crate::table::BaseTable;
|
||||
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
use super::client::{
|
||||
ClientConfig, HeaderProvider, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender,
|
||||
};
|
||||
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
|
||||
use super::table::RemoteTable;
|
||||
use super::util::parse_server_version;
|
||||
|
||||
// Wire types for the derived-compute routes (functions, materialized
|
||||
// views, jobs). Field shapes mirror the server's REST contract.
|
||||
#[derive(serde::Serialize)]
|
||||
struct RemoteCreateFunctionRequest {
|
||||
language: String,
|
||||
return_type: String,
|
||||
body: String,
|
||||
options: std::collections::HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteFunctionEntry {
|
||||
name: String,
|
||||
language: String,
|
||||
return_type: String,
|
||||
#[serde(default)]
|
||||
description: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteListFunctionsResponse {
|
||||
functions: Vec<RemoteFunctionEntry>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RemoteCreateMaterializedViewRequest {
|
||||
query: String,
|
||||
auto_refresh: bool,
|
||||
with_no_data: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
partition_by: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteCreateMaterializedViewResponse {
|
||||
#[serde(default)]
|
||||
job_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RemoteRefreshMaterializedViewRequest {
|
||||
#[serde(skip_serializing_if = "std::ops::Not::not")]
|
||||
full: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
src_version: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
num_workers: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_workers: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteRefreshMaterializedViewResponse {
|
||||
job_id: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RemoteExplainRefreshRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
full: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
src_version: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteExplainRefreshResponse {
|
||||
table_name: String,
|
||||
has_work: bool,
|
||||
source_version: u64,
|
||||
last_refreshed_version: Option<u64>,
|
||||
full_refresh: bool,
|
||||
rebuild: bool,
|
||||
units_total: u64,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RemoteAlterMaterializedViewRequest {
|
||||
auto_refresh: bool,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteMaterializedViewEntry {
|
||||
name: String,
|
||||
source_table: String,
|
||||
#[serde(default)]
|
||||
projection: Vec<String>,
|
||||
#[serde(default)]
|
||||
udf_columns: Vec<String>,
|
||||
#[serde(default)]
|
||||
filter: Option<String>,
|
||||
#[serde(default)]
|
||||
auto_refresh: bool,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteListMaterializedViewsResponse {
|
||||
views: Vec<RemoteMaterializedViewEntry>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteJobEntry {
|
||||
table: String,
|
||||
job_id: String,
|
||||
job_type: String,
|
||||
state: String,
|
||||
#[serde(default)]
|
||||
column: Option<String>,
|
||||
#[serde(default)]
|
||||
age_seconds: Option<i64>,
|
||||
#[serde(default)]
|
||||
command: Option<String>,
|
||||
#[serde(default)]
|
||||
units_done: Option<i64>,
|
||||
#[serde(default)]
|
||||
units_total: Option<i64>,
|
||||
#[serde(default)]
|
||||
committed: bool,
|
||||
#[serde(default)]
|
||||
rows_skipped: u64,
|
||||
#[serde(default)]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteListJobsResponse {
|
||||
jobs: Vec<RemoteJobEntry>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteGetJobResponse {
|
||||
#[serde(default)]
|
||||
job: Option<RemoteJobEntry>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteCancelJobResponse {
|
||||
cancelled: bool,
|
||||
}
|
||||
|
||||
impl From<RemoteJobEntry> for JobInfo {
|
||||
fn from(j: RemoteJobEntry) -> Self {
|
||||
JobInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
age_seconds: j.age_seconds,
|
||||
command: j.command,
|
||||
units_done: j.units_done,
|
||||
units_total: j.units_total,
|
||||
committed: j.committed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteJobHistoryEntry {
|
||||
table: String,
|
||||
job_id: String,
|
||||
job_type: String,
|
||||
state: String,
|
||||
#[serde(default)]
|
||||
column: Option<String>,
|
||||
created_ms: i64,
|
||||
updated_ms: i64,
|
||||
#[serde(default)]
|
||||
completed_ms: Option<i64>,
|
||||
#[serde(default)]
|
||||
rows_processed: Option<i64>,
|
||||
#[serde(default)]
|
||||
rows_skipped: Option<i64>,
|
||||
#[serde(default)]
|
||||
error: Option<String>,
|
||||
#[serde(default)]
|
||||
events: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteJobHistoryResponse {
|
||||
jobs: Vec<RemoteJobHistoryEntry>,
|
||||
}
|
||||
|
||||
impl From<RemoteJobHistoryEntry> for JobHistoryInfo {
|
||||
fn from(j: RemoteJobHistoryEntry) -> Self {
|
||||
JobHistoryInfo {
|
||||
table: j.table,
|
||||
job_id: j.job_id,
|
||||
job_type: j.job_type,
|
||||
state: j.state,
|
||||
column: j.column,
|
||||
created_ms: j.created_ms,
|
||||
updated_ms: j.updated_ms,
|
||||
completed_ms: j.completed_ms,
|
||||
rows_processed: j.rows_processed,
|
||||
rows_skipped: j.rows_skipped,
|
||||
error: j.error,
|
||||
events: j.events,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteErrorEntry {
|
||||
job_id: String,
|
||||
table: String,
|
||||
column: String,
|
||||
error_type: String,
|
||||
error_message: String,
|
||||
#[serde(default)]
|
||||
fragment_id: Option<i64>,
|
||||
#[serde(default)]
|
||||
source_row_id: Option<i64>,
|
||||
#[serde(default)]
|
||||
table_version: Option<i64>,
|
||||
#[serde(default)]
|
||||
age_seconds: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RemoteErrorsResponse {
|
||||
errors: Vec<RemoteErrorEntry>,
|
||||
}
|
||||
|
||||
impl From<RemoteErrorEntry> for JobErrorInfo {
|
||||
fn from(e: RemoteErrorEntry) -> Self {
|
||||
JobErrorInfo {
|
||||
job_id: e.job_id,
|
||||
table: e.table,
|
||||
column: e.column,
|
||||
error_type: e.error_type,
|
||||
error_message: e.error_message,
|
||||
fragment_id: e.fragment_id,
|
||||
source_row_id: e.source_row_id,
|
||||
table_version: e.table_version,
|
||||
age_seconds: e.age_seconds,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Request structure for the remote clone table API
|
||||
#[derive(serde::Serialize)]
|
||||
struct RemoteCloneTableRequest {
|
||||
@@ -441,66 +194,10 @@ pub struct RemoteDatabase<S: HttpSend = Sender> {
|
||||
uri: String,
|
||||
/// Headers to pass to the namespace client for authentication
|
||||
namespace_headers: HashMap<String, String>,
|
||||
namespace_context_provider: Option<Arc<dyn DynamicContextProvider>>,
|
||||
/// TLS configuration for mTLS support
|
||||
tls_config: Option<super::client::TlsConfig>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct NamespaceHeaderProviderContext {
|
||||
header_provider: Arc<dyn HeaderProvider>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for NamespaceHeaderProviderContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("NamespaceHeaderProviderContext")
|
||||
.field("header_provider", &"Some(...)")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl DynamicContextProvider for NamespaceHeaderProviderContext {
|
||||
fn provide_context(&self, _info: &OperationInfo) -> HashMap<String, String> {
|
||||
let header_provider = Arc::clone(&self.header_provider);
|
||||
let handle = match std::thread::Builder::new()
|
||||
.name("lancedb-namespace-headers".to_string())
|
||||
.spawn(move || {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!(
|
||||
"Failed to create runtime for namespace header provider: {e}"
|
||||
),
|
||||
})?
|
||||
.block_on(header_provider.get_headers())
|
||||
}) {
|
||||
Ok(handle) => handle,
|
||||
Err(err) => {
|
||||
log::warn!("Failed to spawn dynamic namespace header provider thread: {err}");
|
||||
return HashMap::new();
|
||||
}
|
||||
};
|
||||
|
||||
let headers = handle.join();
|
||||
|
||||
match headers {
|
||||
Ok(Ok(headers)) => headers
|
||||
.into_iter()
|
||||
.map(|(key, value)| (format!("headers.{key}"), value))
|
||||
.collect(),
|
||||
Ok(Err(err)) => {
|
||||
log::warn!("Failed to get dynamic namespace headers: {err}");
|
||||
HashMap::new()
|
||||
}
|
||||
Err(_) => {
|
||||
log::warn!("Dynamic namespace header provider panicked");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteDatabase {
|
||||
pub fn try_new(
|
||||
uri: &str,
|
||||
@@ -531,16 +228,6 @@ impl RemoteDatabase {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let namespace_context_provider =
|
||||
client_config
|
||||
.header_provider
|
||||
.as_ref()
|
||||
.map(|header_provider| {
|
||||
Arc::new(NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::clone(header_provider),
|
||||
}) as Arc<dyn DynamicContextProvider>
|
||||
});
|
||||
|
||||
let client = RestfulLanceDbClient::try_new(
|
||||
&parsed,
|
||||
region,
|
||||
@@ -560,7 +247,6 @@ impl RemoteDatabase {
|
||||
table_cache,
|
||||
uri: uri.to_owned(),
|
||||
namespace_headers,
|
||||
namespace_context_provider,
|
||||
tls_config: client_config.tls_config,
|
||||
})
|
||||
}
|
||||
@@ -585,7 +271,6 @@ mod test_utils {
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: HashMap::new(),
|
||||
namespace_context_provider: None,
|
||||
tls_config: None,
|
||||
}
|
||||
}
|
||||
@@ -596,18 +281,11 @@ mod test_utils {
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler_and_config(handler, config.clone());
|
||||
let namespace_context_provider =
|
||||
config.header_provider.as_ref().map(|header_provider| {
|
||||
Arc::new(NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::clone(header_provider),
|
||||
}) as Arc<dyn DynamicContextProvider>
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
uri: "http://localhost".to_string(),
|
||||
namespace_headers: config.extra_headers.clone(),
|
||||
namespace_context_provider,
|
||||
tls_config: config.tls_config.clone(),
|
||||
}
|
||||
}
|
||||
@@ -885,228 +563,6 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
Ok(table)
|
||||
}
|
||||
|
||||
async fn create_function(&self, request: CreateFunctionRequest) -> Result<()> {
|
||||
let body = RemoteCreateFunctionRequest {
|
||||
language: request.language,
|
||||
return_type: request.return_type,
|
||||
body: request.body,
|
||||
options: request.options,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/function/{}/create", request.name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, rsp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_functions(&self) -> Result<Vec<FunctionInfo>> {
|
||||
let req = self.client.get("/v1/function/list");
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteListFunctionsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|f| FunctionInfo {
|
||||
name: f.name,
|
||||
language: f.language,
|
||||
return_type: f.return_type,
|
||||
description: f.description,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn drop_function(&self, name: &str) -> Result<()> {
|
||||
let req = self.client.post(&format!("/v1/function/{}/drop", name));
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, rsp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_materialized_view(
|
||||
&self,
|
||||
request: CreateMaterializedViewRequest,
|
||||
) -> Result<Option<String>> {
|
||||
let body = RemoteCreateMaterializedViewRequest {
|
||||
query: request.query,
|
||||
auto_refresh: request.auto_refresh,
|
||||
with_no_data: request.with_no_data,
|
||||
partition_by: request.partition_by,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/create", request.name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteCreateMaterializedViewResponse =
|
||||
rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
async fn refresh_materialized_view(
|
||||
&self,
|
||||
request: RefreshMaterializedViewRequest,
|
||||
) -> Result<String> {
|
||||
let body = RemoteRefreshMaterializedViewRequest {
|
||||
full: request.full,
|
||||
src_version: request.src_version,
|
||||
num_workers: request.num_workers,
|
||||
max_workers: request.max_workers,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/refresh", request.name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteRefreshMaterializedViewResponse =
|
||||
rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
async fn table_lineage(&self, request: TableLineageRequest) -> Result<String> {
|
||||
let mut req = self
|
||||
.client
|
||||
.get(&format!("/v1/table/{}/lineage", request.name));
|
||||
if let Some(column) = &request.column {
|
||||
req = req.query(&[("column", column)]);
|
||||
}
|
||||
if let Some(direction) = &request.direction {
|
||||
req = req.query(&[("direction", direction)]);
|
||||
}
|
||||
if let Some(depth) = request.depth {
|
||||
req = req.query(&[("depth", depth.to_string())]);
|
||||
}
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
// Server-defined lineage JSON, returned opaque (the client does not
|
||||
// model the lineage schema; the Python layer deserializes it).
|
||||
rsp.text().await.err_to_http(request_id)
|
||||
}
|
||||
|
||||
async fn explain_refresh_materialized_view(
|
||||
&self,
|
||||
name: &str,
|
||||
full: bool,
|
||||
src_version: Option<u64>,
|
||||
) -> Result<MvRefreshPlan> {
|
||||
let body = RemoteExplainRefreshRequest {
|
||||
full: Some(full),
|
||||
src_version,
|
||||
};
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/explain_refresh", name))
|
||||
.json(&body);
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteExplainRefreshResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(MvRefreshPlan {
|
||||
table_name: body.table_name,
|
||||
has_work: body.has_work,
|
||||
source_version: body.source_version,
|
||||
last_refreshed_version: body.last_refreshed_version,
|
||||
full_refresh: body.full_refresh,
|
||||
rebuild: body.rebuild,
|
||||
units_total: body.units_total,
|
||||
})
|
||||
}
|
||||
|
||||
async fn alter_materialized_view(&self, name: &str, auto_refresh: bool) -> Result<()> {
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/alter", name))
|
||||
.json(&RemoteAlterMaterializedViewRequest { auto_refresh });
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, rsp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn drop_materialized_view(&self, name: &str) -> Result<()> {
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/materialized_view/{}/drop", name));
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, rsp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_materialized_views(&self) -> Result<Vec<MaterializedViewInfo>> {
|
||||
let req = self.client.get("/v1/materialized_view/list");
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteListMaterializedViewsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body
|
||||
.views
|
||||
.into_iter()
|
||||
.map(|v| MaterializedViewInfo {
|
||||
name: v.name,
|
||||
source_table: v.source_table,
|
||||
projection: v.projection,
|
||||
udf_columns: v.udf_columns,
|
||||
filter: v.filter,
|
||||
auto_refresh: v.auto_refresh,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn list_jobs(&self) -> Result<Vec<JobInfo>> {
|
||||
let req = self.client.get("/v1/job/list");
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteListJobsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.jobs.into_iter().map(JobInfo::from).collect())
|
||||
}
|
||||
|
||||
async fn get_job(&self, job_id: &str, table: Option<&str>) -> Result<Option<JobInfo>> {
|
||||
// Point-access poll path: GET /v1/job/{id}, with the table as the O(1)
|
||||
// hint when known. `query` handles URL-encoding the table name.
|
||||
let mut req = self.client.get(&format!("/v1/job/{job_id}"));
|
||||
if let Some(t) = table {
|
||||
req = req.query(&[("table", t)]);
|
||||
}
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteGetJobResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job.map(JobInfo::from))
|
||||
}
|
||||
|
||||
async fn cancel_job(&self, job_id: &str) -> Result<bool> {
|
||||
let req = self.client.post(&format!("/v1/job/{}/cancel", job_id));
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteCancelJobResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.cancelled)
|
||||
}
|
||||
|
||||
async fn job_history(&self, job_id: Option<&str>) -> Result<Vec<JobHistoryInfo>> {
|
||||
let mut req = self.client.get("/v1/job/history");
|
||||
if let Some(j) = job_id {
|
||||
req = req.query(&[("job", j)]);
|
||||
}
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteJobHistoryResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.jobs.into_iter().map(JobHistoryInfo::from).collect())
|
||||
}
|
||||
|
||||
async fn errors(&self, job_id: Option<&str>, table: Option<&str>) -> Result<Vec<JobErrorInfo>> {
|
||||
let mut req = self.client.get("/v1/job/errors");
|
||||
if let Some(j) = job_id {
|
||||
req = req.query(&[("job", j)]);
|
||||
}
|
||||
if let Some(t) = table {
|
||||
req = req.query(&[("table", t)]);
|
||||
}
|
||||
let (request_id, rsp) = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let body: RemoteErrorsResponse = rsp.json().await.err_to_http(request_id)?;
|
||||
Ok(body.errors.into_iter().map(JobErrorInfo::from).collect())
|
||||
}
|
||||
|
||||
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
let identifier = build_table_identifier(
|
||||
&request.name,
|
||||
@@ -1303,12 +759,9 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
// Create a RestNamespace pointing to the same remote host with the same authentication headers
|
||||
let mut builder = lance_namespace_impls::RestNamespaceBuilder::new(self.client.host())
|
||||
.delimiter(&self.client.id_delimiter)
|
||||
// TODO: support header provider
|
||||
.headers(self.namespace_headers.clone());
|
||||
|
||||
if let Some(context_provider) = &self.namespace_context_provider {
|
||||
builder = builder.context_provider(Arc::clone(context_provider));
|
||||
}
|
||||
|
||||
// Apply mTLS configuration if present
|
||||
if let Some(tls_config) = &self.tls_config {
|
||||
if let Some(cert_file) = &tls_config.cert_file {
|
||||
@@ -1328,14 +781,6 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
}
|
||||
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||
if self.namespace_context_provider.is_some() {
|
||||
return Err(Error::NotSupported {
|
||||
message:
|
||||
"Cannot export a namespace client config when dynamic headers are configured; use LanceDB connection namespace methods instead"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert("uri".to_string(), self.client.host().to_string());
|
||||
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
|
||||
@@ -1387,13 +832,12 @@ impl From<StorageOptions> for RemoteOptions {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{NamespaceHeaderProviderContext, build_cache_key};
|
||||
use super::build_cache_key;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use arrow_array::{Int32Array, RecordBatch};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance_namespace_impls::{DynamicContextProvider, OperationInfo};
|
||||
|
||||
use crate::connection::ConnectBuilder;
|
||||
use crate::{
|
||||
@@ -2046,223 +1490,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_derived_compute_routes() {
|
||||
// create_function
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/function/embed/create");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["language"], "python");
|
||||
assert_eq!(body["return_type"], "FLOAT[4]");
|
||||
assert_eq!(body["body"], "def embed(x): ...");
|
||||
assert_eq!(body["options"]["pip"], "torch");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"embed","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.create_function(crate::database::CreateFunctionRequest {
|
||||
name: "embed".into(),
|
||||
language: "python".into(),
|
||||
return_type: "FLOAT[4]".into(),
|
||||
body: "def embed(x): ...".into(),
|
||||
options: [("pip".to_string(), "torch".to_string())].into(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// list_functions
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/function/list");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"functions":[{"name":"embed","language":"python","return_type":"Float32","description":""}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let functions = conn.list_functions().await.unwrap();
|
||||
assert_eq!(functions.len(), 1);
|
||||
assert_eq!(functions[0].name, "embed");
|
||||
|
||||
// drop_function
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/function/embed/drop");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"embed","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.drop_function("embed").await.unwrap();
|
||||
|
||||
// create_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/create");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["query"], "SELECT id, embed(body) AS vec FROM docs");
|
||||
assert_eq!(body["auto_refresh"], true);
|
||||
assert_eq!(body["with_no_data"], false);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"mv1","job_id":"j-1"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let mut request = crate::database::CreateMaterializedViewRequest::new(
|
||||
"mv1",
|
||||
"SELECT id, embed(body) AS vec FROM docs",
|
||||
);
|
||||
request.auto_refresh = true;
|
||||
let job_id = conn.create_materialized_view(request).await.unwrap();
|
||||
assert_eq!(job_id.as_deref(), Some("j-1"));
|
||||
|
||||
// refresh_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/refresh");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["num_workers"], 2);
|
||||
assert!(body.get("src_version").is_none());
|
||||
http::Response::builder()
|
||||
.status(202)
|
||||
.body(r#"{"job_id":"j-2"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let mut request = crate::database::RefreshMaterializedViewRequest::new("mv1");
|
||||
request.num_workers = Some(2);
|
||||
let job_id = conn.refresh_materialized_view(request).await.unwrap();
|
||||
assert_eq!(job_id, "j-2");
|
||||
|
||||
// alter_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/alter");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["auto_refresh"], false);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"mv1","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.alter_materialized_view("mv1", false).await.unwrap();
|
||||
|
||||
// drop_materialized_view
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/mv1/drop");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"name":"mv1","status":"OK"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
conn.drop_materialized_view("mv1").await.unwrap();
|
||||
|
||||
// list_materialized_views
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/materialized_view/list");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"views":[{"name":"mv1","source_table":"docs","projection":["id"],"udf_columns":["vec=embed(body)"],"filter":null,"auto_refresh":true}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let views = conn.list_materialized_views().await.unwrap();
|
||||
assert_eq!(views.len(), 1);
|
||||
assert_eq!(views[0].source_table, "docs");
|
||||
assert!(views[0].auto_refresh);
|
||||
|
||||
// list_jobs
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/job/list");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"jobs":[{"table":"docs","job_id":"j-3","job_type":"udf_virtual_column_backfill","state":"running","column":"vec","age_seconds":4,"command":null,"units_done":1,"units_total":2,"committed":false,"rows_skipped":0,"error":null}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let jobs = conn.list_jobs().await.unwrap();
|
||||
assert_eq!(jobs.len(), 1);
|
||||
assert_eq!(jobs[0].state, "running");
|
||||
assert_eq!(jobs[0].units_total, Some(2));
|
||||
|
||||
// cancel_job
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/job/j-3/cancel");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"cancelled":true}"#)
|
||||
.unwrap()
|
||||
});
|
||||
assert!(conn.cancel_job("j-3").await.unwrap());
|
||||
|
||||
// cancel_job: no such inflight job -> false, not an error
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.url().path(), "/v1/job/gone/cancel");
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"cancelled":false}"#)
|
||||
.unwrap()
|
||||
});
|
||||
assert!(!conn.cancel_job("gone").await.unwrap());
|
||||
|
||||
// job_history: GET /v1/job/history, no filter
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/job/history");
|
||||
assert!(request.url().query().is_none());
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"jobs":[{"table":"docs","job_id":"j-1","job_type":"udf_virtual_column_backfill","state":"done","column":"vec","created_ms":1000,"updated_ms":2000,"completed_ms":2000,"rows_processed":42,"rows_skipped":3,"error":null,"events":"created\ndone"}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let hist = conn.job_history(None).await.unwrap();
|
||||
assert_eq!(hist.len(), 1);
|
||||
assert_eq!(hist[0].state, "done");
|
||||
assert_eq!(hist[0].rows_processed, Some(42));
|
||||
assert_eq!(hist[0].events.as_deref(), Some("created\ndone"));
|
||||
|
||||
// job_history: ?job= narrows to one job
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.url().path(), "/v1/job/history");
|
||||
assert_eq!(request.url().query(), Some("job=j-1"));
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"jobs":[]}"#)
|
||||
.unwrap()
|
||||
});
|
||||
assert!(conn.job_history(Some("j-1")).await.unwrap().is_empty());
|
||||
|
||||
// errors: GET /v1/job/errors with job + table filters
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/job/errors");
|
||||
assert_eq!(request.url().query(), Some("job=j-1&table=docs"));
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"errors":[{"job_id":"j-1","table":"docs","column":"vec","error_type":"ValueError","error_message":"boom","fragment_id":0,"source_row_id":42,"table_version":7,"age_seconds":5}]}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let errs = conn.errors(Some("j-1"), Some("docs")).await.unwrap();
|
||||
assert_eq!(errs.len(), 1);
|
||||
assert_eq!(errs[0].error_type, "ValueError");
|
||||
assert_eq!(errs[0].source_row_id, Some(42));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clone_table() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
@@ -2475,75 +1702,6 @@ mod tests {
|
||||
assert!(namespace_client.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_namespace_header_provider_context_maps_headers() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let context_provider = NamespaceHeaderProviderContext {
|
||||
header_provider: Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>,
|
||||
};
|
||||
|
||||
let context =
|
||||
context_provider.provide_context(&OperationInfo::new("list_tables", "namespace"));
|
||||
|
||||
assert_eq!(
|
||||
context.get("headers.authorization"),
|
||||
Some(&"Bearer token".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_namespace_client_supports_dynamic_headers() {
|
||||
#[derive(Debug)]
|
||||
struct TestHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
"Bearer token".to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(Arc::new(TestHeaderProvider) as Arc<dyn HeaderProvider>),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
|_| {
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": []}"#)
|
||||
.unwrap()
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
let namespace_client = conn.namespace_client().await;
|
||||
assert!(namespace_client.is_ok());
|
||||
|
||||
match conn.namespace_client_config().await {
|
||||
Err(Error::NotSupported { message })
|
||||
if message.contains("dynamic headers are configured") => {}
|
||||
Err(err) => panic!("expected NotSupported, got {err:?}"),
|
||||
Ok(_) => panic!("expected namespace_client_config to reject dynamic headers"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Integration tests using RestAdapter to run RemoteDatabase against a real namespace server
|
||||
mod rest_adapter_integration {
|
||||
use super::*;
|
||||
|
||||
@@ -1,907 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::debug;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::remote::client::HeaderProvider;
|
||||
|
||||
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
|
||||
const DEFAULT_TOKEN_TTL_SECS: u64 = 3600;
|
||||
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
|
||||
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
|
||||
|
||||
/// OAuth authentication flow configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OAuthFlow {
|
||||
/// Client Credentials grant (service-to-service / M2M).
|
||||
/// Requires `client_secret` in [`OAuthConfig`].
|
||||
ClientCredentials,
|
||||
|
||||
/// Azure Managed Identity via IMDS.
|
||||
/// Works on Azure VMs, AKS, App Service, and Azure Functions.
|
||||
/// IMDS requests bypass proxy settings because the endpoint is link-local.
|
||||
AzureManagedIdentity {
|
||||
/// Client ID for user-assigned managed identity.
|
||||
/// Omit for system-assigned managed identity.
|
||||
client_id: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// OAuth configuration for LanceDB authentication.
|
||||
///
|
||||
/// All token acquisition and refresh is handled in the Rust layer.
|
||||
/// Python and TypeScript bindings expose this as a plain config object.
|
||||
#[derive(Clone)]
|
||||
pub struct OAuthConfig {
|
||||
/// OIDC issuer URL or OAuth authority URL.
|
||||
/// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
|
||||
pub issuer_url: String,
|
||||
|
||||
/// Application / Client ID.
|
||||
pub client_id: String,
|
||||
|
||||
/// Client secret (required for `ClientCredentials`, optional for others).
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
/// OAuth scopes to request.
|
||||
/// For Azure managed identity, exactly one scope or resource is required.
|
||||
/// For example: `["api://{app_id}/.default"]`
|
||||
pub scopes: Vec<String>,
|
||||
|
||||
/// Authentication flow to use.
|
||||
pub flow: OAuthFlow,
|
||||
|
||||
/// Seconds before token expiry to trigger proactive refresh (default: 300).
|
||||
/// Keep this well below the token TTL; if it is greater than or equal to
|
||||
/// the TTL, each request refreshes the token.
|
||||
pub refresh_buffer_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthConfig")
|
||||
.field("issuer_url", &self.issuer_url)
|
||||
.field("client_id", &self.client_id)
|
||||
.field(
|
||||
"client_secret",
|
||||
&self.client_secret.as_deref().map(|_| "<redacted>"),
|
||||
)
|
||||
.field("scopes", &self.scopes)
|
||||
.field("flow", &self.flow)
|
||||
.field("refresh_buffer_secs", &self.refresh_buffer_secs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// -- OIDC Discovery --
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
struct OidcDiscovery {
|
||||
token_endpoint: String,
|
||||
}
|
||||
|
||||
// -- Token Response --
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
/// Token lifetime in seconds.
|
||||
/// Some providers (Azure IMDS) return this as a string, so we accept both.
|
||||
#[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
|
||||
expires_in: Option<u64>,
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TokenResponse {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TokenResponse")
|
||||
.field("access_token", &"<redacted>")
|
||||
.field("expires_in", &self.expires_in)
|
||||
.field("token_type", &self.token_type)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_optional_u64_or_string<'de, D>(
|
||||
deserializer: D,
|
||||
) -> std::result::Result<Option<u64>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de;
|
||||
|
||||
struct U64OrString;
|
||||
impl<'de> de::Visitor<'de> for U64OrString {
|
||||
type Value = Option<u64>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("an integer, an integer-valued float, a numeric string, or null")
|
||||
}
|
||||
|
||||
fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
|
||||
Ok(Some(v))
|
||||
}
|
||||
|
||||
fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
|
||||
if v < 0 {
|
||||
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||
}
|
||||
Ok(Some(v as u64))
|
||||
}
|
||||
|
||||
fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
|
||||
if !v.is_finite() || v < 0.0 || v.fract() != 0.0 || v > u64::MAX as f64 {
|
||||
return Err(E::custom(format!("invalid expires_in value: {v}")));
|
||||
}
|
||||
Ok(Some(v as u64))
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
|
||||
v.parse::<u64>().map(Some).map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(U64OrString)
|
||||
}
|
||||
|
||||
// -- Internal Token State --
|
||||
|
||||
struct TokenState {
|
||||
access_token: Option<String>,
|
||||
expires_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl TokenState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
access_token: None,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_expired(&self, buffer: Duration) -> bool {
|
||||
match (self.access_token.as_ref(), self.expires_at) {
|
||||
(Some(_), Some(expires_at)) => Instant::now() + buffer >= expires_at,
|
||||
(None, _) => true,
|
||||
(Some(_), None) => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, resp: &TokenResponse) {
|
||||
self.access_token = Some(resp.access_token.clone());
|
||||
let expires_in = resp.expires_in.unwrap_or(DEFAULT_TOKEN_TTL_SECS);
|
||||
self.expires_at = Some(Instant::now() + Duration::from_secs(expires_in));
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
trait TokenSource: Send + Sync + std::fmt::Debug {
|
||||
async fn fetch_token(&self) -> Result<TokenResponse>;
|
||||
}
|
||||
|
||||
struct ClientCredentialsSource {
|
||||
issuer_url: String,
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
scopes: Vec<String>,
|
||||
http_client: Client,
|
||||
discovery: RwLock<Option<OidcDiscovery>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ClientCredentialsSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ClientCredentialsSource")
|
||||
.field("issuer_url", &self.issuer_url)
|
||||
.field("client_id", &self.client_id)
|
||||
.field("client_secret", &"<redacted>")
|
||||
.field("scopes", &self.scopes)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentialsSource {
|
||||
fn new(
|
||||
issuer_url: String,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
scopes: Vec<String>,
|
||||
) -> Result<Self> {
|
||||
let client_secret = client_secret.ok_or(Error::InvalidInput {
|
||||
message: "client_secret is required for ClientCredentials flow".to_string(),
|
||||
})?;
|
||||
Self::validate_issuer_transport(&issuer_url)?;
|
||||
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create HTTP client for OAuth: {e}"),
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
issuer_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
scopes,
|
||||
http_client,
|
||||
discovery: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_issuer_transport(issuer_url: &str) -> Result<()> {
|
||||
let issuer = url::Url::parse(issuer_url).map_err(|e| Error::InvalidInput {
|
||||
message: format!("Invalid OAuth issuer_url: {e}"),
|
||||
})?;
|
||||
|
||||
match issuer.scheme() {
|
||||
"https" => Ok(()),
|
||||
"http" if Self::is_loopback_issuer(&issuer) => Ok(()),
|
||||
_ => Err(Error::InvalidInput {
|
||||
message:
|
||||
"ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
|
||||
.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_loopback_issuer(issuer: &url::Url) -> bool {
|
||||
let Some(host) = issuer.host_str() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
host.eq_ignore_ascii_case("localhost")
|
||||
|| host
|
||||
.parse::<IpAddr>()
|
||||
.map(|addr| addr.is_loopback())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn get_discovery(&self) -> Result<OidcDiscovery> {
|
||||
{
|
||||
let cached = self.discovery.read().await;
|
||||
if let Some(ref disc) = *cached {
|
||||
return Ok(disc.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = self.discovery.write().await;
|
||||
// Double-check
|
||||
if let Some(ref disc) = *cache {
|
||||
return Ok(disc.clone());
|
||||
}
|
||||
|
||||
let discovery_url = format!(
|
||||
"{}/.well-known/openid-configuration",
|
||||
self.issuer_url.trim_end_matches('/')
|
||||
);
|
||||
|
||||
debug!("Fetching OIDC discovery from {}", discovery_url);
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to fetch OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"OIDC discovery failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let disc: OidcDiscovery = resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse OIDC discovery document: {e}"),
|
||||
})?;
|
||||
|
||||
let result = disc.clone();
|
||||
|
||||
*cache = Some(disc);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn get_token_endpoint(&self) -> Result<String> {
|
||||
self.get_discovery().await.map(|disc| disc.token_endpoint)
|
||||
}
|
||||
|
||||
fn scopes_string(&self) -> String {
|
||||
self.scopes.join(" ")
|
||||
}
|
||||
|
||||
async fn post_token_request(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
params: &[(&str, &str)],
|
||||
) -> Result<TokenResponse> {
|
||||
let resp = self
|
||||
.http_client
|
||||
.post(endpoint)
|
||||
.form(params)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Token request to {endpoint} failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Token request failed with status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse token response: {e}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenSource for ClientCredentialsSource {
|
||||
async fn fetch_token(&self) -> Result<TokenResponse> {
|
||||
let token_endpoint = self.get_token_endpoint().await?;
|
||||
let scope = self.scopes_string();
|
||||
let params = [
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", self.client_id.as_str()),
|
||||
("client_secret", self.client_secret.as_str()),
|
||||
("scope", scope.as_str()),
|
||||
];
|
||||
|
||||
self.post_token_request(&token_endpoint, ¶ms).await
|
||||
}
|
||||
}
|
||||
|
||||
struct AzureImdsSource {
|
||||
client_id: Option<String>,
|
||||
resource: String,
|
||||
http_client: Client,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AzureImdsSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AzureImdsSource")
|
||||
.field("client_id", &self.client_id)
|
||||
.field("resource", &self.resource)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AzureImdsSource {
|
||||
fn new(scopes: Vec<String>, client_id: Option<String>) -> Result<Self> {
|
||||
let resource = Self::resource_from_scopes(&scopes)?;
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.no_proxy()
|
||||
.build()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create HTTP client for Azure IMDS OAuth: {e}"),
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
client_id,
|
||||
resource,
|
||||
http_client,
|
||||
})
|
||||
}
|
||||
|
||||
fn resource_from_scopes(scopes: &[String]) -> Result<String> {
|
||||
let [scope] = scopes else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "AzureManagedIdentity flow requires exactly one OAuth scope or resource"
|
||||
.to_string(),
|
||||
});
|
||||
};
|
||||
|
||||
Ok(scope.strip_suffix("/.default").unwrap_or(scope).to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenSource for AzureImdsSource {
|
||||
async fn fetch_token(&self) -> Result<TokenResponse> {
|
||||
let mut url = format!(
|
||||
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
|
||||
urlencoding::encode(&self.resource),
|
||||
);
|
||||
if let Some(cid) = self.client_id.as_deref() {
|
||||
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.header("Metadata", "true")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Azure IMDS request failed: {e}"),
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::Runtime {
|
||||
message: format!(
|
||||
"Azure IMDS returned status {}: {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse IMDS token response: {e}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth header provider that manages the full token lifecycle.
|
||||
///
|
||||
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
|
||||
/// headers into every LanceDB request, with automatic token refresh.
|
||||
pub struct OAuthHeaderProvider {
|
||||
token_source: Box<dyn TokenSource>,
|
||||
token_state: Arc<RwLock<TokenState>>,
|
||||
refresh_buffer: Duration,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OAuthHeaderProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OAuthHeaderProvider")
|
||||
.field("token_source", &self.token_source)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthHeaderProvider {
|
||||
/// Create a new OAuth header provider from configuration.
|
||||
pub fn new(config: OAuthConfig) -> Result<Self> {
|
||||
let OAuthConfig {
|
||||
issuer_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
scopes,
|
||||
flow,
|
||||
refresh_buffer_secs,
|
||||
} = config;
|
||||
|
||||
if scopes.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "At least one OAuth scope is required".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let refresh_buffer =
|
||||
Duration::from_secs(refresh_buffer_secs.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS));
|
||||
let token_source: Box<dyn TokenSource> = match flow {
|
||||
OAuthFlow::ClientCredentials => Box::new(ClientCredentialsSource::new(
|
||||
issuer_url,
|
||||
client_id,
|
||||
client_secret,
|
||||
scopes,
|
||||
)?),
|
||||
OAuthFlow::AzureManagedIdentity { client_id } => {
|
||||
Box::new(AzureImdsSource::new(scopes, client_id)?)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
token_source,
|
||||
token_state: Arc::new(RwLock::new(TokenState::new())),
|
||||
refresh_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid access token, refreshing if necessary.
|
||||
async fn get_valid_token(&self) -> Result<String> {
|
||||
// Fast path: check if current token is still valid
|
||||
{
|
||||
let state = self.token_state.read().await;
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: acquire or refresh token
|
||||
let mut state = self.token_state.write().await;
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if !state.is_expired(self.refresh_buffer)
|
||||
&& let Some(ref token) = state.access_token
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
debug!("Acquiring new OAuth token via {:?}", self.token_source);
|
||||
let resp = self.token_source.fetch_token().await?;
|
||||
|
||||
state.update(&resp);
|
||||
Ok(resp.access_token)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HeaderProvider for OAuthHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
let token = self.get_valid_token().await?;
|
||||
Ok(HashMap::from([(
|
||||
"authorization".to_string(),
|
||||
format!("Bearer {token}"),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[test]
|
||||
fn test_token_state_expiry() {
|
||||
let mut state = TokenState::new();
|
||||
assert!(state.is_expired(Duration::from_secs(0)));
|
||||
|
||||
state.access_token = Some("tok".to_string());
|
||||
state.expires_at = Some(Instant::now() + Duration::from_secs(600));
|
||||
assert!(!state.is_expired(Duration::from_secs(300)));
|
||||
assert!(state.is_expired(Duration::from_secs(601)));
|
||||
|
||||
state.expires_at = None;
|
||||
assert!(state.is_expired(Duration::from_secs(0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_state_uses_default_expiry() {
|
||||
let mut state = TokenState::new();
|
||||
let response = TokenResponse {
|
||||
access_token: "tok".to_string(),
|
||||
expires_in: None,
|
||||
token_type: None,
|
||||
};
|
||||
|
||||
state.update(&response);
|
||||
|
||||
assert!(!state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS - 1)));
|
||||
assert!(state.is_expired(Duration::from_secs(DEFAULT_TOKEN_TTL_SECS + 1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_accepts_float_expires_in() {
|
||||
let response: TokenResponse =
|
||||
serde_json::from_str(r#"{"access_token":"tok","expires_in":3600.0}"#).unwrap();
|
||||
|
||||
assert_eq!(response.expires_in, Some(3600));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_rejects_negative_expires_in() {
|
||||
let err =
|
||||
serde_json::from_str::<TokenResponse>(r#"{"access_token":"tok","expires_in":-1}"#)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("invalid expires_in value: -1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_response_debug_redacts_access_token() {
|
||||
let response = TokenResponse {
|
||||
access_token: "secret-token".to_string(),
|
||||
expires_in: Some(3600),
|
||||
token_type: Some("Bearer".to_string()),
|
||||
};
|
||||
|
||||
let debug = format!("{response:?}");
|
||||
assert!(!debug.contains("secret-token"));
|
||||
assert!(debug.contains("access_token: \"<redacted>\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scopes_string() {
|
||||
let source = ClientCredentialsSource::new(
|
||||
"https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
"app-id".to_string(),
|
||||
Some("secret".to_string()),
|
||||
vec!["scope1".to_string(), "scope2".to_string()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(source.scopes_string(), "scope1 scope2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_config_debug_redacts_client_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("super-secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let debug = format!("{config:?}");
|
||||
assert!(!debug.contains("super-secret"));
|
||||
assert!(debug.contains("client_secret: Some(\"<redacted>\")"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth_header_provider_debug_redacts_client_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://issuer.example.com".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("super-secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
let debug = format!("{provider:?}");
|
||||
assert!(!debug.contains("super-secret"));
|
||||
assert!(debug.contains("client_secret: \"<redacted>\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_managed_identity_resource_from_default_scope() {
|
||||
assert_eq!(
|
||||
AzureImdsSource::resource_from_scopes(&["api://test/.default".to_string()]).unwrap(),
|
||||
"api://test"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_managed_identity_resource_without_default_suffix() {
|
||||
assert_eq!(
|
||||
AzureImdsSource::resource_from_scopes(&["api://test".to_string()]).unwrap(),
|
||||
"api://test"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_managed_identity_rejects_multiple_scopes() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec![
|
||||
"api://test-a/.default".to_string(),
|
||||
"api://test-b/.default".to_string(),
|
||||
],
|
||||
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_endpoint_requires_discovery_success() {
|
||||
let (issuer_url, server) = spawn_discovery_error_server().await;
|
||||
let source = ClientCredentialsSource::new(
|
||||
issuer_url,
|
||||
"client-id".to_string(),
|
||||
Some("secret".to_string()),
|
||||
vec!["scope".to_string()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let err = source.get_token_endpoint().await.unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::Runtime { message }
|
||||
if message.contains("OIDC discovery failed with status 503")
|
||||
));
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_credentials_requires_secret() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_credentials_rejects_insecure_non_loopback_issuer() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "http://issuer.example.com".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
|
||||
let err = OAuthHeaderProvider::new(config).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::InvalidInput { message }
|
||||
if message == "ClientCredentials OAuth issuer_url must use https, except for loopback hosts"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_scopes_rejected() {
|
||||
let config = OAuthConfig {
|
||||
issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
|
||||
client_id: "app-id".to_string(),
|
||||
client_secret: None,
|
||||
scopes: vec![],
|
||||
flow: OAuthFlow::AzureManagedIdentity { client_id: None },
|
||||
refresh_buffer_secs: None,
|
||||
};
|
||||
assert!(OAuthHeaderProvider::new(config).is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_credentials_token_lifecycle() {
|
||||
let (issuer_url, token_requests, server) = spawn_oauth_server().await;
|
||||
let config = OAuthConfig {
|
||||
issuer_url,
|
||||
client_id: "client-id".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
scopes: vec!["scope".to_string()],
|
||||
flow: OAuthFlow::ClientCredentials,
|
||||
refresh_buffer_secs: Some(0),
|
||||
};
|
||||
let provider = OAuthHeaderProvider::new(config).unwrap();
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-1");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 1);
|
||||
|
||||
provider.token_state.write().await.expires_at =
|
||||
Some(Instant::now() - Duration::from_secs(1));
|
||||
|
||||
let headers = provider.get_headers().await.unwrap();
|
||||
assert_eq!(headers.get("authorization").unwrap(), "Bearer token-2");
|
||||
assert_eq!(token_requests.load(Ordering::SeqCst), 2);
|
||||
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
async fn spawn_oauth_server() -> (String, Arc<AtomicUsize>, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let issuer_url = format!("http://{addr}");
|
||||
let token_requests = Arc::new(AtomicUsize::new(0));
|
||||
let server_token_requests = Arc::clone(&token_requests);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
for _ in 0..3 {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let (request_line, body) = read_http_request(&mut stream).await;
|
||||
|
||||
if request_line.starts_with("GET /.well-known/openid-configuration ") {
|
||||
let discovery = format!(r#"{{"token_endpoint":"http://{addr}/token"}}"#);
|
||||
write_json_response(&mut stream, "200 OK", &discovery).await;
|
||||
} else if request_line.starts_with("POST /token ") {
|
||||
assert!(body.contains("grant_type=client_credentials"));
|
||||
assert!(body.contains("client_id=client-id"));
|
||||
assert!(body.contains("client_secret=secret"));
|
||||
assert!(body.contains("scope=scope"));
|
||||
|
||||
let token_num = server_token_requests.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let token = format!(
|
||||
r#"{{"access_token":"token-{token_num}","expires_in":3600,"token_type":"Bearer"}}"#
|
||||
);
|
||||
write_json_response(&mut stream, "200 OK", &token).await;
|
||||
} else {
|
||||
write_json_response(&mut stream, "404 Not Found", "{}").await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(issuer_url, token_requests, server)
|
||||
}
|
||||
|
||||
async fn spawn_discovery_error_server() -> (String, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let issuer_url = format!("http://{addr}");
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let (request_line, _) = read_http_request(&mut stream).await;
|
||||
assert!(request_line.starts_with("GET /.well-known/openid-configuration "));
|
||||
write_json_response(&mut stream, "503 Service Unavailable", "{}").await;
|
||||
});
|
||||
|
||||
(issuer_url, server)
|
||||
}
|
||||
|
||||
async fn read_http_request(stream: &mut TcpStream) -> (String, String) {
|
||||
let mut buffer = Vec::new();
|
||||
let mut header_end = None;
|
||||
|
||||
while header_end.is_none() {
|
||||
let mut chunk = [0; 1024];
|
||||
let read = stream.read(&mut chunk).await.unwrap();
|
||||
assert_ne!(read, 0, "connection closed before request headers");
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
header_end = find_subsequence(&buffer, b"\r\n\r\n").map(|pos| pos + 4);
|
||||
}
|
||||
|
||||
let header_end = header_end.unwrap();
|
||||
let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
|
||||
let request_line = headers.lines().next().unwrap_or_default().to_string();
|
||||
let content_length = headers
|
||||
.lines()
|
||||
.find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
name.eq_ignore_ascii_case("content-length")
|
||||
.then(|| value.trim().parse::<usize>().ok())
|
||||
.flatten()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
while buffer.len() < header_end + content_length {
|
||||
let mut chunk = [0; 1024];
|
||||
let read = stream.read(&mut chunk).await.unwrap();
|
||||
assert_ne!(read, 0, "connection closed before request body");
|
||||
buffer.extend_from_slice(&chunk[..read]);
|
||||
}
|
||||
|
||||
let body =
|
||||
String::from_utf8_lossy(&buffer[header_end..header_end + content_length]).to_string();
|
||||
|
||||
(request_line, body)
|
||||
}
|
||||
|
||||
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
|
||||
haystack
|
||||
.windows(needle.len())
|
||||
.position(|window| window == needle)
|
||||
}
|
||||
|
||||
async fn write_json_response(stream: &mut TcpStream, status: &str, body: &str) {
|
||||
let response = format!(
|
||||
"HTTP/1.1 {status}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream.write_all(response.as_bytes()).await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -1352,35 +1352,6 @@ impl<S: HttpSend + 'static> RemoteTable<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize an index's `created_at` field.
|
||||
///
|
||||
/// The server returns this as an RFC 3339 string (e.g. `"2026-06-18T21:37:36.637Z"`),
|
||||
/// but older deployments sent a unix timestamp in milliseconds. Accept both so the
|
||||
/// client works against any server version.
|
||||
fn deserialize_created_at<'de, D>(
|
||||
deserializer: D,
|
||||
) -> std::result::Result<Option<DateTime<Utc>>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error as _;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum CreatedAt {
|
||||
Rfc3339(String),
|
||||
Millis(i64),
|
||||
}
|
||||
|
||||
match Option::<CreatedAt>::deserialize(deserializer)? {
|
||||
None => Ok(None),
|
||||
Some(CreatedAt::Rfc3339(s)) => DateTime::parse_from_rfc3339(&s)
|
||||
.map(|dt| Some(dt.with_timezone(&Utc)))
|
||||
.map_err(D::Error::custom),
|
||||
Some(CreatedAt::Millis(ms)) => Ok(DateTime::from_timestamp_millis(ms)),
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: HttpSend + 'static> RemoteTable<S> {
|
||||
/// Parse the response from `/index/list/` into `IndexConfig` entries.
|
||||
///
|
||||
@@ -1409,7 +1380,7 @@ impl<S: HttpSend + 'static> RemoteTable<S> {
|
||||
// Used as the sentinel to decide whether to skip the stats call.
|
||||
index_type: Option<IndexType>,
|
||||
index_uuid: Option<String>,
|
||||
#[serde(default, deserialize_with = "deserialize_created_at")]
|
||||
#[serde(default, with = "chrono::serde::ts_milliseconds_option")]
|
||||
created_at: Option<DateTime<Utc>>,
|
||||
num_indexed_rows: Option<u64>,
|
||||
num_unindexed_rows: Option<u64>,
|
||||
@@ -2309,126 +2280,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
message: "optimize is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
async fn add_computed_columns(
|
||||
&self,
|
||||
columns: &[(String, String)],
|
||||
expression: &str,
|
||||
) -> Result<()> {
|
||||
let new_columns: Vec<serde_json::Value> = columns
|
||||
.iter()
|
||||
.map(|(name, data_type)| {
|
||||
serde_json::json!({
|
||||
"name": name,
|
||||
"computed": { "data_type": data_type, "expression": expression },
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/add_columns/", self.identifier))
|
||||
.json(&serde_json::json!({ "new_columns": new_columns }));
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn refresh_column(
|
||||
&self,
|
||||
columns: &[String],
|
||||
where_clause: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> Result<String> {
|
||||
let mut body = serde_json::json!({ "columns": columns });
|
||||
if let Some(w) = where_clause {
|
||||
body["where_clause"] = serde_json::Value::String(w);
|
||||
}
|
||||
if let Some(n) = num_workers {
|
||||
body["num_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = max_workers {
|
||||
body["max_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = batch_size {
|
||||
body["batch_size"] = n.into();
|
||||
}
|
||||
if let Some(p) = priority {
|
||||
body["priority"] = serde_json::Value::String(p);
|
||||
}
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/refresh_column", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RefreshColumnResponse {
|
||||
job_id: String,
|
||||
}
|
||||
let body: RefreshColumnResponse = response.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
async fn load_columns(&self, request: crate::table::LoadColumnsRequest) -> Result<String> {
|
||||
let columns: Vec<serde_json::Value> = request
|
||||
.columns
|
||||
.iter()
|
||||
.map(|(target, source)| {
|
||||
serde_json::json!({
|
||||
"target": target,
|
||||
"source": source.clone().unwrap_or_else(|| target.clone()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mut source = serde_json::json!({
|
||||
"uris": request.source_uris,
|
||||
"format": request.source_format,
|
||||
});
|
||||
if let Some(opts) = request.source_storage_options {
|
||||
source["storage_options"] = serde_json::to_value(opts).unwrap_or_default();
|
||||
}
|
||||
let mut body = serde_json::json!({
|
||||
"columns": columns,
|
||||
"source": source,
|
||||
"target_key": request.target_key,
|
||||
});
|
||||
if let Some(k) = request.source_key {
|
||||
body["source_key"] = serde_json::Value::String(k);
|
||||
}
|
||||
if let Some(m) = request.on_missing {
|
||||
body["on_missing"] = serde_json::Value::String(m);
|
||||
}
|
||||
if let Some(n) = request.num_workers {
|
||||
body["num_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = request.max_workers {
|
||||
body["max_workers"] = n.into();
|
||||
}
|
||||
if let Some(n) = request.batch_size {
|
||||
body["batch_size"] = n.into();
|
||||
}
|
||||
if let Some(n) = request.commit_granularity {
|
||||
body["commit_granularity"] = n.into();
|
||||
}
|
||||
if let Some(p) = request.priority {
|
||||
body["priority"] = serde_json::Value::String(p);
|
||||
}
|
||||
let http_request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/load_columns", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(http_request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
#[derive(serde::Deserialize)]
|
||||
struct LoadColumnsResponse {
|
||||
job_id: String,
|
||||
}
|
||||
let body: LoadColumnsResponse = response.json().await.err_to_http(request_id)?;
|
||||
Ok(body.job_id)
|
||||
}
|
||||
|
||||
async fn add_columns(
|
||||
&self,
|
||||
transforms: NewColumnTransform,
|
||||
@@ -2921,75 +2772,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_column() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/refresh_column");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(body["columns"], serde_json::json!(["vec"]));
|
||||
assert_eq!(body["num_workers"], 2);
|
||||
assert!(body.get("where_clause").is_none());
|
||||
|
||||
http::Response::builder()
|
||||
.status(202)
|
||||
.body(r#"{"job_id":"j-9"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let job_id = table
|
||||
.refresh_column(&["vec".to_string()], None, Some(2), None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(job_id, "j-9");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_columns() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/load_columns");
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
|
||||
assert_eq!(
|
||||
body["columns"],
|
||||
serde_json::json!([{"target": "embedding", "source": "emb"}])
|
||||
);
|
||||
assert_eq!(body["source"]["format"], "parquet");
|
||||
assert_eq!(
|
||||
body["source"]["uris"],
|
||||
serde_json::json!(["s3://b/x.parquet"])
|
||||
);
|
||||
assert_eq!(body["target_key"], "document_id");
|
||||
assert_eq!(body["source_key"], "doc_id");
|
||||
assert_eq!(body["on_missing"], "null");
|
||||
assert_eq!(body["num_workers"], 4);
|
||||
|
||||
http::Response::builder()
|
||||
.status(202)
|
||||
.body(r#"{"job_id":"lc-7"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let request = crate::table::LoadColumnsRequest {
|
||||
source_uris: vec!["s3://b/x.parquet".to_string()],
|
||||
source_format: "parquet".to_string(),
|
||||
source_storage_options: None,
|
||||
target_key: "document_id".to_string(),
|
||||
source_key: Some("doc_id".to_string()),
|
||||
columns: vec![("embedding".to_string(), Some("emb".to_string()))],
|
||||
on_missing: Some("null".to_string()),
|
||||
num_workers: Some(4),
|
||||
max_workers: None,
|
||||
batch_size: None,
|
||||
commit_granularity: None,
|
||||
priority: None,
|
||||
};
|
||||
let job_id = table.load_columns(request).await.unwrap();
|
||||
assert_eq!(job_id, "lc-7");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_version() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
@@ -4896,7 +4678,7 @@ mod tests {
|
||||
"num_segments": 2,
|
||||
"index_version": 1,
|
||||
"index_details": "{\"num_partitions\":16}",
|
||||
"created_at": "2026-06-18T21:37:36.637Z",
|
||||
"created_at": 1700000000000i64,
|
||||
"type_url": "type.googleapis.com/lance.index.vector.IvfPq",
|
||||
},
|
||||
{
|
||||
@@ -4946,10 +4728,7 @@ mod tests {
|
||||
vec_idx.type_url,
|
||||
Some("type.googleapis.com/lance.index.vector.IvfPq".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
vec_idx.created_at,
|
||||
Some("2026-06-18T21:37:36.637Z".parse::<DateTime<Utc>>().unwrap())
|
||||
);
|
||||
assert!(vec_idx.created_at.is_some());
|
||||
|
||||
let text_idx = &indices[1];
|
||||
assert_eq!(text_idx.name, "text_idx");
|
||||
@@ -4970,36 +4749,6 @@ mod tests {
|
||||
assert_eq!(text_idx.created_at, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_created_at() {
|
||||
#[derive(Deserialize)]
|
||||
struct Wrapper {
|
||||
#[serde(default, deserialize_with = "deserialize_created_at")]
|
||||
created_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
// RFC 3339 string (current server format).
|
||||
let w: Wrapper =
|
||||
serde_json::from_str(r#"{"created_at": "2026-06-18T21:37:36.637Z"}"#).unwrap();
|
||||
assert_eq!(
|
||||
w.created_at,
|
||||
Some("2026-06-18T21:37:36.637Z".parse::<DateTime<Utc>>().unwrap())
|
||||
);
|
||||
|
||||
// Unix milliseconds (legacy server format).
|
||||
let w: Wrapper = serde_json::from_str(r#"{"created_at": 1700000000000}"#).unwrap();
|
||||
assert_eq!(w.created_at, DateTime::from_timestamp_millis(1700000000000));
|
||||
|
||||
// Null and missing both yield None.
|
||||
let w: Wrapper = serde_json::from_str(r#"{"created_at": null}"#).unwrap();
|
||||
assert_eq!(w.created_at, None);
|
||||
let w: Wrapper = serde_json::from_str(r#"{}"#).unwrap();
|
||||
assert_eq!(w.created_at, None);
|
||||
|
||||
// A malformed string is rejected rather than silently dropped to None.
|
||||
assert!(serde_json::from_str::<Wrapper>(r#"{"created_at": "not-a-date"}"#).is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_versions() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
//! LanceDB Table APIs
|
||||
|
||||
use arrow_array::{LargeBinaryArray, RecordBatch, RecordBatchReader};
|
||||
use arrow_array::{RecordBatch, RecordBatchReader};
|
||||
use arrow_schema::{Schema, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use datafusion_execution::TaskContext;
|
||||
@@ -12,7 +12,6 @@ use datafusion_physical_plan::ExecutionPlan;
|
||||
use datafusion_physical_plan::display::DisplayableExecutionPlan;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use lance::dataset::BlobFile;
|
||||
pub use lance::dataset::ColumnAlteration;
|
||||
pub use lance::dataset::NewColumnTransform;
|
||||
pub use lance::dataset::ReadParams;
|
||||
@@ -44,7 +43,6 @@ use crate::connection::NamespaceClientPushdownOperation;
|
||||
|
||||
use crate::data::scannable::{PeekedScannable, Scannable, estimate_write_partitions};
|
||||
use crate::database::Database;
|
||||
use crate::database::read_freshness::TableFreshness;
|
||||
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MemoryRegistry};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::index::IndexStatistics;
|
||||
@@ -471,33 +469,6 @@ impl LsmWriteSpec {
|
||||
}
|
||||
}
|
||||
|
||||
/// Request to fill existing table columns from an external source by
|
||||
/// primary-key join (Geneva `Table.load_columns()` parity). Server-backed
|
||||
/// feature (LanceDB Enterprise / Cloud).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoadColumnsRequest {
|
||||
/// External source URIs.
|
||||
pub source_uris: Vec<String>,
|
||||
/// Source format: "parquet" | "lance" | "ipc".
|
||||
pub source_format: String,
|
||||
/// Source-only storage options (e.g. cloud credentials).
|
||||
pub source_storage_options: Option<HashMap<String, String>>,
|
||||
/// Destination primary-key column.
|
||||
pub target_key: String,
|
||||
/// Source primary-key column. Defaults to `target_key` when None.
|
||||
pub source_key: Option<String>,
|
||||
/// Value column mappings as `(target, source)`; a None source defaults to
|
||||
/// the target name.
|
||||
pub columns: Vec<(String, Option<String>)>,
|
||||
/// Missing-row policy: "carry" (default) | "null" | "error".
|
||||
pub on_missing: Option<String>,
|
||||
pub num_workers: Option<u32>,
|
||||
pub max_workers: Option<u32>,
|
||||
pub batch_size: Option<u32>,
|
||||
pub commit_granularity: Option<u32>,
|
||||
pub priority: Option<String>,
|
||||
}
|
||||
|
||||
/// A trait for anything "table-like". This is used for both native tables (which target
|
||||
/// Lance datasets) and remote tables (which target LanceDB cloud)
|
||||
///
|
||||
@@ -615,28 +586,6 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
async fn close_lsm_writers(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
/// Names of the blob v2 columns in this table, in declaration order.
|
||||
async fn blob_columns(&self) -> Result<Vec<String>> {
|
||||
Err(Error::NotSupported {
|
||||
message: "blob_columns is not supported on this table type".into(),
|
||||
})
|
||||
}
|
||||
/// Materialize blob bytes for the given row ids. See [`Table::fetch_blobs`].
|
||||
async fn fetch_blobs(&self, _column: &str, _row_ids: &[u64]) -> Result<LargeBinaryArray> {
|
||||
Err(Error::NotSupported {
|
||||
message: "fetch_blobs is not supported on this table type".into(),
|
||||
})
|
||||
}
|
||||
/// Open lazy blob handles for the given row ids. See [`Table::fetch_blob_files`].
|
||||
async fn fetch_blob_files(
|
||||
&self,
|
||||
_column: &str,
|
||||
_row_ids: &[u64],
|
||||
) -> Result<Vec<Option<BlobFile>>> {
|
||||
Err(Error::NotSupported {
|
||||
message: "fetch_blob_files is not supported on this table type".into(),
|
||||
})
|
||||
}
|
||||
/// Gets the table tag manager.
|
||||
async fn tags(&self) -> Result<Box<dyn Tags + '_>>;
|
||||
/// Optimize the dataset.
|
||||
@@ -647,47 +596,6 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
transforms: NewColumnTransform,
|
||||
read_columns: Option<Vec<String>>,
|
||||
) -> Result<AddColumnsResult>;
|
||||
/// Declare computed columns bound to a registered function: each
|
||||
/// `(name, sql_type)` is added all-null with the expression stored
|
||||
/// as its binding; no compute happens here (the server's lazy
|
||||
/// detector or refresh_column fills them). Several columns map a
|
||||
/// struct-returning function's fields positionally. Server-backed
|
||||
/// feature; the default returns NotSupported.
|
||||
async fn add_computed_columns(
|
||||
&self,
|
||||
_columns: &[(String, String)],
|
||||
_expression: &str,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "computed columns are not supported by this table".into(),
|
||||
})
|
||||
}
|
||||
/// Trigger recompute of computed columns. The expression is
|
||||
/// resolved server-side from each column's stored binding; columns
|
||||
/// bound to the same struct-returning function refresh together.
|
||||
/// Returns the refresh job id. Server-backed feature (LanceDB
|
||||
/// Enterprise / Cloud); the default returns NotSupported.
|
||||
async fn refresh_column(
|
||||
&self,
|
||||
_columns: &[String],
|
||||
_where_clause: Option<String>,
|
||||
_num_workers: Option<u32>,
|
||||
_max_workers: Option<u32>,
|
||||
_batch_size: Option<u32>,
|
||||
_priority: Option<String>,
|
||||
) -> Result<String> {
|
||||
Err(Error::NotSupported {
|
||||
message: "refresh_column is not supported by this table".into(),
|
||||
})
|
||||
}
|
||||
/// Fill existing columns from an external source by primary-key join
|
||||
/// (Geneva `load_columns`). Returns the load job id. Server-backed feature;
|
||||
/// the default returns NotSupported.
|
||||
async fn load_columns(&self, _request: LoadColumnsRequest) -> Result<String> {
|
||||
Err(Error::NotSupported {
|
||||
message: "load_columns is not supported by this table".into(),
|
||||
})
|
||||
}
|
||||
/// Alter columns in the table.
|
||||
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult>;
|
||||
/// Drop columns from the table.
|
||||
@@ -1018,76 +926,6 @@ impl Table {
|
||||
self.inner.count_rows(filter.map(Filter::Sql)).await
|
||||
}
|
||||
|
||||
/// Names of the blob v2 columns in this table, in declaration order.
|
||||
///
|
||||
/// Nested blobs use dotted paths (e.g. `info.blob`). Returns
|
||||
/// [`Error::NotSupported`] on table types without blob support.
|
||||
pub async fn blob_columns(&self) -> Result<Vec<String>> {
|
||||
self.inner.blob_columns().await
|
||||
}
|
||||
|
||||
/// Materialize blob bytes for the given row ids.
|
||||
///
|
||||
/// Output matches `row_ids` in length and order. Null and zero-length rows
|
||||
/// are null. Prefer [`Self::fetch_blob_files`] for large selections.
|
||||
///
|
||||
/// ```
|
||||
/// use arrow_array::UInt64Array;
|
||||
/// use futures::TryStreamExt;
|
||||
/// use lancedb::query::{ExecutableQuery, QueryBase};
|
||||
///
|
||||
/// # use lancedb::Table;
|
||||
/// # async fn materialize(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
|
||||
/// let mut stream = table.query().with_row_id().limit(10).execute().await?;
|
||||
/// while let Some(batch) = stream.try_next().await? {
|
||||
/// let row_ids = batch
|
||||
/// .column_by_name("_rowid")
|
||||
/// .unwrap()
|
||||
/// .as_any()
|
||||
/// .downcast_ref::<UInt64Array>()
|
||||
/// .unwrap();
|
||||
/// let images = table.fetch_blobs("image", row_ids.values()).await?;
|
||||
/// let _ = images;
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Returns [`Error::InvalidInput`] when the column does not exist or is
|
||||
/// not a blob v2 column, and [`Error::NotSupported`] on table types
|
||||
/// without blob support.
|
||||
pub async fn fetch_blobs(
|
||||
&self,
|
||||
column: impl AsRef<str>,
|
||||
row_ids: &[u64],
|
||||
) -> Result<LargeBinaryArray> {
|
||||
self.inner.fetch_blobs(column.as_ref(), row_ids).await
|
||||
}
|
||||
|
||||
/// Open lazy [`BlobFile`] handles for the given row ids.
|
||||
///
|
||||
/// Same length and order as `row_ids`. Null rows are `None`. Bytes are not
|
||||
/// read from disk until a call to [`BlobFile::read`].
|
||||
///
|
||||
/// ```
|
||||
/// # use lancedb::Table;
|
||||
/// # async fn lazy_read(table: &Table, row_ids: &[u64]) -> Result<(), Box<dyn std::error::Error>> {
|
||||
/// let handles = table.fetch_blob_files("image", row_ids).await?;
|
||||
/// if let Some(Some(first)) = handles.first() {
|
||||
/// let bytes = first.read().await?;
|
||||
/// println!("first blob is {} bytes", bytes.len());
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn fetch_blob_files(
|
||||
&self,
|
||||
column: impl AsRef<str>,
|
||||
row_ids: &[u64],
|
||||
) -> Result<Vec<Option<BlobFile>>> {
|
||||
self.inner.fetch_blob_files(column.as_ref(), row_ids).await
|
||||
}
|
||||
|
||||
/// Insert new records into this Table
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -1529,48 +1367,6 @@ impl Table {
|
||||
self.inner.add_columns(transforms, read_columns).await
|
||||
}
|
||||
|
||||
/// Declare computed columns bound to a registered function
|
||||
/// (`(name, sql_type)` pairs + a `f(args)` expression). No compute
|
||||
/// happens here. Server-backed feature.
|
||||
pub async fn add_computed_columns(
|
||||
&self,
|
||||
columns: &[(String, String)],
|
||||
expression: &str,
|
||||
) -> Result<()> {
|
||||
self.inner.add_computed_columns(columns, expression).await
|
||||
}
|
||||
|
||||
/// Trigger recompute of computed columns (REFRESH COLUMN). The
|
||||
/// expression comes from each column's stored binding; columns
|
||||
/// bound to the same struct-returning function refresh together.
|
||||
/// Returns the refresh job id. Server-backed feature.
|
||||
pub async fn refresh_column(
|
||||
&self,
|
||||
columns: &[String],
|
||||
where_clause: Option<String>,
|
||||
num_workers: Option<u32>,
|
||||
max_workers: Option<u32>,
|
||||
batch_size: Option<u32>,
|
||||
priority: Option<String>,
|
||||
) -> Result<String> {
|
||||
self.inner
|
||||
.refresh_column(
|
||||
columns,
|
||||
where_clause,
|
||||
num_workers,
|
||||
max_workers,
|
||||
batch_size,
|
||||
priority,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fill existing columns from an external Parquet/Lance/IPC source by
|
||||
/// primary-key join (Geneva `Table.load_columns()`). Returns the job id.
|
||||
pub async fn load_columns(&self, request: LoadColumnsRequest) -> Result<String> {
|
||||
self.inner.load_columns(request).await
|
||||
}
|
||||
|
||||
/// Change a column's name or nullability.
|
||||
pub async fn alter_columns(
|
||||
&self,
|
||||
@@ -1967,8 +1763,6 @@ pub struct NativeTable {
|
||||
// Operations to push down to the namespace server.
|
||||
// pub(crate) so query.rs can access the field for server-side query execution.
|
||||
pub(crate) pushdown_operations: HashSet<NamespaceClientPushdownOperation>,
|
||||
// Read-freshness baseline; `Some` only for namespace-backed tables.
|
||||
freshness: Option<TableFreshness>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for NativeTable {
|
||||
@@ -2129,7 +1923,6 @@ impl NativeTable {
|
||||
read_consistency_interval,
|
||||
namespace_client,
|
||||
pushdown_operations,
|
||||
freshness: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2141,12 +1934,6 @@ impl NativeTable {
|
||||
self
|
||||
}
|
||||
|
||||
/// Attach the read-freshness baseline handle (namespace connections only).
|
||||
pub(crate) fn with_freshness(mut self, freshness: TableFreshness) -> Self {
|
||||
self.freshness = Some(freshness);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build a sibling `NativeTable` with the same identity but a different
|
||||
/// (independent) dataset wrapper — used to hand out branch-scoped handles.
|
||||
fn with_dataset(&self, dataset: dataset::DatasetConsistencyWrapper) -> Self {
|
||||
@@ -2159,14 +1946,6 @@ impl NativeTable {
|
||||
read_consistency_interval: self.read_consistency_interval,
|
||||
namespace_client: self.namespace_client.clone(),
|
||||
pushdown_operations: self.pushdown_operations.clone(),
|
||||
freshness: self.freshness.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Bump the read-freshness baseline; no-op for non-namespace tables.
|
||||
fn bump_freshness(&self) {
|
||||
if let Some(freshness) = &self.freshness {
|
||||
freshness.bump();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2266,7 +2045,6 @@ impl NativeTable {
|
||||
read_consistency_interval,
|
||||
namespace_client: stored_namespace_client,
|
||||
pushdown_operations,
|
||||
freshness: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2356,7 +2134,6 @@ impl NativeTable {
|
||||
read_consistency_interval,
|
||||
namespace_client,
|
||||
pushdown_operations,
|
||||
freshness: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2488,7 +2265,6 @@ impl NativeTable {
|
||||
read_consistency_interval,
|
||||
namespace_client: stored_namespace_client,
|
||||
pushdown_operations,
|
||||
freshness: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2648,8 +2424,6 @@ impl BaseTable for NativeTable {
|
||||
}
|
||||
|
||||
async fn checkout_latest(&self) -> Result<()> {
|
||||
// Bump before resolving "latest" so that request carries the floor.
|
||||
self.bump_freshness();
|
||||
self.dataset.as_latest().await?;
|
||||
self.dataset.reload().await
|
||||
}
|
||||
@@ -2737,8 +2511,6 @@ impl BaseTable for NativeTable {
|
||||
debug_assert_eq!(dataset.version().version, version);
|
||||
dataset.restore().await?;
|
||||
}
|
||||
// Restore moves "latest", so bump before resolving it (as RemoteTable does).
|
||||
self.bump_freshness();
|
||||
self.dataset.as_latest().await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -2819,13 +2591,7 @@ impl BaseTable for NativeTable {
|
||||
output.plan
|
||||
};
|
||||
|
||||
let insert_exec = Arc::new(InsertExec::new_with_tracker(
|
||||
ds_wrapper.clone(),
|
||||
ds,
|
||||
plan,
|
||||
lance_params,
|
||||
output.tracker.clone(),
|
||||
));
|
||||
let insert_exec = Arc::new(InsertExec::new(ds_wrapper.clone(), ds, plan, lance_params));
|
||||
|
||||
let tracker_for_tasks = output.tracker.clone();
|
||||
if let Some(ref t) = tracker_for_tasks {
|
||||
@@ -2858,7 +2624,6 @@ impl BaseTable for NativeTable {
|
||||
}
|
||||
|
||||
let version = ds_wrapper.get().await?.manifest().version;
|
||||
self.bump_freshness();
|
||||
Ok(AddResult { version })
|
||||
}
|
||||
|
||||
@@ -2909,9 +2674,7 @@ impl BaseTable for NativeTable {
|
||||
|
||||
async fn update(&self, update: UpdateBuilder) -> Result<UpdateResult> {
|
||||
// Delegate to the submodule implementation
|
||||
let result = update::execute_update(self, update).await?;
|
||||
self.bump_freshness();
|
||||
Ok(result)
|
||||
update::execute_update(self, update).await
|
||||
}
|
||||
|
||||
async fn create_plan(
|
||||
@@ -2943,9 +2706,7 @@ impl BaseTable for NativeTable {
|
||||
params: MergeInsertBuilder,
|
||||
new_data: Box<dyn RecordBatchReader + Send>,
|
||||
) -> Result<MergeResult> {
|
||||
let result = merge::execute_merge_insert(self, params, new_data).await?;
|
||||
self.bump_freshness();
|
||||
Ok(result)
|
||||
merge::execute_merge_insert(self, params, new_data).await
|
||||
}
|
||||
|
||||
async fn set_unenforced_primary_key(&self, columns: &[&str]) -> Result<()> {
|
||||
@@ -2964,30 +2725,9 @@ impl BaseTable for NativeTable {
|
||||
merge::lsm::close_lsm_writers(self).await
|
||||
}
|
||||
|
||||
async fn blob_columns(&self) -> Result<Vec<String>> {
|
||||
let schema = self.schema().await?;
|
||||
Ok(crate::blob::blob_column_names(schema.as_ref()))
|
||||
}
|
||||
|
||||
async fn fetch_blobs(&self, column: &str, row_ids: &[u64]) -> Result<LargeBinaryArray> {
|
||||
let dataset = self.dataset.get().await?;
|
||||
crate::blob::take_blobs_aligned(&dataset, column, row_ids).await
|
||||
}
|
||||
|
||||
async fn fetch_blob_files(
|
||||
&self,
|
||||
column: &str,
|
||||
row_ids: &[u64],
|
||||
) -> Result<Vec<Option<BlobFile>>> {
|
||||
let dataset = self.dataset.get().await?;
|
||||
crate::blob::take_blob_files_aligned(&dataset, column, row_ids).await
|
||||
}
|
||||
|
||||
/// Delete rows from the table
|
||||
async fn delete(&self, predicate: Predicate<'_>) -> Result<DeleteResult> {
|
||||
let result = delete::execute_delete(self, predicate).await?;
|
||||
self.bump_freshness();
|
||||
Ok(result)
|
||||
delete::execute_delete(self, predicate).await
|
||||
}
|
||||
|
||||
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
|
||||
@@ -3006,30 +2746,22 @@ impl BaseTable for NativeTable {
|
||||
transforms: NewColumnTransform,
|
||||
read_columns: Option<Vec<String>>,
|
||||
) -> Result<AddColumnsResult> {
|
||||
let result = schema_evolution::execute_add_columns(self, transforms, read_columns).await?;
|
||||
self.bump_freshness();
|
||||
Ok(result)
|
||||
schema_evolution::execute_add_columns(self, transforms, read_columns).await
|
||||
}
|
||||
|
||||
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<AlterColumnsResult> {
|
||||
let result = schema_evolution::execute_alter_columns(self, alterations).await?;
|
||||
self.bump_freshness();
|
||||
Ok(result)
|
||||
schema_evolution::execute_alter_columns(self, alterations).await
|
||||
}
|
||||
|
||||
async fn update_field_metadata(
|
||||
&self,
|
||||
updates: &[FieldMetadataUpdate],
|
||||
) -> Result<UpdateFieldMetadataResult> {
|
||||
let result = schema_evolution::execute_update_field_metadata(self, updates).await?;
|
||||
self.bump_freshness();
|
||||
Ok(result)
|
||||
schema_evolution::execute_update_field_metadata(self, updates).await
|
||||
}
|
||||
|
||||
async fn drop_columns(&self, columns: &[&str]) -> Result<DropColumnsResult> {
|
||||
let result = schema_evolution::execute_drop_columns(self, columns).await?;
|
||||
self.bump_freshness();
|
||||
Ok(result)
|
||||
schema_evolution::execute_drop_columns(self, columns).await
|
||||
}
|
||||
|
||||
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
|
||||
|
||||
@@ -26,9 +26,6 @@ pub enum AddDataMode {
|
||||
#[default]
|
||||
Append,
|
||||
/// The existing table will be overwritten with the new data
|
||||
///
|
||||
/// On overwrite, raw binary is not coerced into a blob struct. The input
|
||||
/// must declare blob v2 for the column to stay a blob column.
|
||||
Overwrite,
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
|
||||
|
||||
mod blob_coerce;
|
||||
pub mod cast;
|
||||
pub mod insert;
|
||||
pub mod reject_nan;
|
||||
|
||||
@@ -1,495 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Coerces write-path input into blob v2 struct columns.
|
||||
//!
|
||||
//! [`super::cast::cast_to_table_schema`] calls [`coerce_blob_expr`].
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::{DataType, Field, FieldRef};
|
||||
use datafusion::functions::core::{get_field, named_struct};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use datafusion_physical_expr::ScalarFunctionExpr;
|
||||
use datafusion_physical_expr::expressions::{CastExpr, Literal};
|
||||
use datafusion_physical_plan::PhysicalExpr;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// Build a projection expression coercing `input_expr` into the blob struct
|
||||
/// declared by `table_field`, composing `named_struct` / `get_field` / `cast`.
|
||||
pub(super) fn coerce_blob_expr(
|
||||
input_expr: Arc<dyn PhysicalExpr>,
|
||||
input_field: &Field,
|
||||
table_field: &FieldRef,
|
||||
config: &Arc<ConfigOptions>,
|
||||
) -> Result<(Arc<dyn PhysicalExpr>, FieldRef)> {
|
||||
let DataType::Struct(declared_fields) = table_field.data_type() else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"blob v2 column '{}' must be a struct, table declares {}",
|
||||
table_field.name(),
|
||||
table_field.data_type()
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
let input_struct_children = match input_field.data_type() {
|
||||
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => None,
|
||||
DataType::Struct(children) => {
|
||||
if !children
|
||||
.iter()
|
||||
.any(|c| c.name() == "data" || c.name() == "uri")
|
||||
{
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"blob struct input for column '{}' must contain a 'data' or 'uri' child",
|
||||
table_field.name()
|
||||
),
|
||||
});
|
||||
}
|
||||
Some(children)
|
||||
}
|
||||
other => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"cannot coerce column '{}' with type {} into a blob v2 struct. \
|
||||
expected Binary, LargeBinary, BinaryView, or a Struct with a 'data' or 'uri' child",
|
||||
table_field.name(),
|
||||
other,
|
||||
),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let mut ns_args: Vec<Arc<dyn PhysicalExpr>> = Vec::with_capacity(declared_fields.len() * 2);
|
||||
for declared in declared_fields.iter() {
|
||||
ns_args.push(Arc::new(Literal::new(ScalarValue::from(
|
||||
declared.name().as_str(),
|
||||
))));
|
||||
|
||||
let value: Arc<dyn PhysicalExpr> = match input_struct_children {
|
||||
// Raw binary lands in `data` and everything else is a typed null.
|
||||
None => {
|
||||
if declared.name() == "data" {
|
||||
Arc::new(CastExpr::new(
|
||||
input_expr.clone(),
|
||||
declared.data_type().clone(),
|
||||
None,
|
||||
))
|
||||
} else {
|
||||
typed_null(declared.data_type())?
|
||||
}
|
||||
}
|
||||
Some(children) => match children.iter().find(|c| c.name() == declared.name()) {
|
||||
Some(child) => {
|
||||
let field_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
|
||||
&format!("get_field({})", declared.name()),
|
||||
get_field(),
|
||||
vec![
|
||||
input_expr.clone(),
|
||||
Arc::new(Literal::new(ScalarValue::from(declared.name().as_str()))),
|
||||
],
|
||||
Arc::new(child.as_ref().clone()),
|
||||
config.clone(),
|
||||
));
|
||||
if child.data_type() == declared.data_type() {
|
||||
field_expr
|
||||
} else {
|
||||
Arc::new(CastExpr::new(
|
||||
field_expr,
|
||||
declared.data_type().clone(),
|
||||
None,
|
||||
))
|
||||
}
|
||||
}
|
||||
None => typed_null(declared.data_type())?,
|
||||
},
|
||||
};
|
||||
ns_args.push(value);
|
||||
}
|
||||
|
||||
let expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
|
||||
&format!("named_struct({})", table_field.name()),
|
||||
named_struct(),
|
||||
ns_args,
|
||||
table_field.clone(),
|
||||
config.clone(),
|
||||
));
|
||||
Ok((expr, table_field.clone()))
|
||||
}
|
||||
|
||||
fn typed_null(data_type: &DataType) -> Result<Arc<dyn PhysicalExpr>> {
|
||||
let scalar = ScalarValue::try_from(data_type).map_err(|e| Error::InvalidInput {
|
||||
message: format!("cannot build null literal for blob child type {data_type}: {e}"),
|
||||
})?;
|
||||
Ok(Arc::new(Literal::new(scalar)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::cast::cast_to_table_schema;
|
||||
use super::*;
|
||||
use crate::blob::blob;
|
||||
use arrow_array::{
|
||||
Array, ArrayRef, BinaryArray, BinaryViewArray, Int32Array, Int64Array, LargeBinaryArray,
|
||||
RecordBatch, StringArray, StructArray, UInt8Array, UInt64Array,
|
||||
};
|
||||
use arrow_schema::Schema;
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_catalog::MemTable;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use futures::TryStreamExt;
|
||||
use lance_arrow::FieldExt;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn wide_blob_field(name: &str) -> Field {
|
||||
Field::new(
|
||||
name,
|
||||
DataType::Struct(
|
||||
vec![
|
||||
Field::new("data", DataType::LargeBinary, true),
|
||||
Field::new("uri", DataType::Utf8, true),
|
||||
Field::new("position", DataType::UInt64, true),
|
||||
Field::new("size", DataType::UInt64, true),
|
||||
]
|
||||
.into(),
|
||||
),
|
||||
true,
|
||||
)
|
||||
.with_metadata(HashMap::from([(
|
||||
"ARROW:extension:name".to_string(),
|
||||
"lance.blob.v2".to_string(),
|
||||
)]))
|
||||
}
|
||||
|
||||
fn blob_table_schema() -> Schema {
|
||||
Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
blob("image", true),
|
||||
])
|
||||
}
|
||||
|
||||
fn batch_with_image(image_field: Field, image: ArrayRef) -> RecordBatch {
|
||||
let len = image.len();
|
||||
RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
image_field,
|
||||
])),
|
||||
vec![Arc::new(Int64Array::from_iter_values(0..len as i64)), image],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn image_struct(batch: &RecordBatch) -> &StructArray {
|
||||
batch
|
||||
.column_by_name("image")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StructArray>()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn plan_from_batch(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
|
||||
let schema = batch.schema();
|
||||
let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
|
||||
let ctx = SessionContext::new();
|
||||
ctx.register_table("t", Arc::new(table)).unwrap();
|
||||
let df = ctx.table("t").await.unwrap();
|
||||
df.create_physical_plan().await.unwrap()
|
||||
}
|
||||
|
||||
async fn coerce(batch: RecordBatch, table_schema: &Schema) -> RecordBatch {
|
||||
let plan = plan_from_batch(batch).await;
|
||||
let plan = cast_to_table_schema(plan, table_schema).unwrap();
|
||||
let ctx = SessionContext::new();
|
||||
let stream = plan.execute(0, ctx.task_ctx()).unwrap();
|
||||
let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
|
||||
arrow_select::concat::concat_batches(&plan.schema(), &batches).unwrap()
|
||||
}
|
||||
|
||||
async fn coerce_err(batch: RecordBatch, table_schema: &Schema) -> Error {
|
||||
let plan = plan_from_batch(batch).await;
|
||||
cast_to_table_schema(plan, table_schema).unwrap_err()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn large_binary_coerces_to_declared_blob_struct() {
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", DataType::LargeBinary, true),
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"hello".as_slice()])),
|
||||
);
|
||||
let coerced = coerce(batch, &blob_table_schema()).await;
|
||||
let image_field = coerced.schema().field_with_name("image").unwrap().clone();
|
||||
assert!(image_field.is_blob_v2());
|
||||
assert!(matches!(image_field.data_type(), DataType::Struct(_)));
|
||||
let data = image_struct(&coerced).column_by_name("data").unwrap();
|
||||
let data: &LargeBinaryArray = data.as_any().downcast_ref().unwrap();
|
||||
assert_eq!(data.value(0), b"hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn binary_coerces_to_declared_blob_struct() {
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", DataType::Binary, true),
|
||||
Arc::new(BinaryArray::from_iter_values([b"hi".as_slice()])),
|
||||
);
|
||||
let coerced = coerce(batch, &blob_table_schema()).await;
|
||||
assert!(
|
||||
coerced
|
||||
.schema()
|
||||
.field_with_name("image")
|
||||
.unwrap()
|
||||
.is_blob_v2()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn binary_view_coerces_to_declared_blob_struct() {
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", DataType::BinaryView, true),
|
||||
Arc::new(BinaryViewArray::from_iter_values([b"view".as_slice()])),
|
||||
);
|
||||
let coerced = coerce(batch, &blob_table_schema()).await;
|
||||
let data = image_struct(&coerced).column_by_name("data").unwrap();
|
||||
let data: &LargeBinaryArray = data.as_any().downcast_ref().unwrap();
|
||||
assert_eq!(data.value(0), b"view");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn binary_nulls_stay_null_after_coercion() {
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", DataType::Binary, true),
|
||||
Arc::new(BinaryArray::from_iter(vec![
|
||||
Some(b"present".as_slice()),
|
||||
None,
|
||||
])),
|
||||
);
|
||||
let coerced = coerce(batch, &blob_table_schema()).await;
|
||||
let image = image_struct(&coerced);
|
||||
let data = image.column_by_name("data").unwrap();
|
||||
assert!(!data.is_null(0));
|
||||
assert!(data.is_null(1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn binary_coerces_into_four_child_blob_layout() {
|
||||
let table_schema = Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
wide_blob_field("image"),
|
||||
]);
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", DataType::LargeBinary, true),
|
||||
Arc::new(LargeBinaryArray::from_iter(vec![
|
||||
Some(b"alpha".as_slice()),
|
||||
None,
|
||||
])),
|
||||
);
|
||||
let coerced = coerce(batch, &table_schema).await;
|
||||
let image = image_struct(&coerced);
|
||||
assert_eq!(
|
||||
image.num_columns(),
|
||||
4,
|
||||
"coerced struct keeps the declared layout"
|
||||
);
|
||||
assert!(image.column_by_name("position").unwrap().is_null(0));
|
||||
assert!(image.column_by_name("size").unwrap().is_null(0));
|
||||
assert!(!image.column_by_name("data").unwrap().is_null(0));
|
||||
assert!(image.column_by_name("data").unwrap().is_null(1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prebuilt_struct_gains_blob_field_metadata() {
|
||||
let DataType::Struct(children) = blob("image", true).data_type().clone() else {
|
||||
unreachable!("blob field is a struct")
|
||||
};
|
||||
let prebuilt = StructArray::new(
|
||||
children,
|
||||
vec![
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"prebuilt".as_slice()])),
|
||||
Arc::new(StringArray::from(vec![None::<&str>])),
|
||||
],
|
||||
None,
|
||||
);
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", prebuilt.data_type().clone(), true),
|
||||
Arc::new(prebuilt),
|
||||
);
|
||||
let coerced = coerce(batch, &blob_table_schema()).await;
|
||||
assert!(
|
||||
coerced
|
||||
.schema()
|
||||
.field_with_name("image")
|
||||
.unwrap()
|
||||
.is_blob_v2()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prebuilt_narrow_struct_widens_to_declared_layout() {
|
||||
let DataType::Struct(narrow_children) = blob("image", true).data_type().clone() else {
|
||||
unreachable!("blob field is a struct")
|
||||
};
|
||||
let prebuilt = StructArray::new(
|
||||
narrow_children,
|
||||
vec![
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"prebuilt".as_slice()])),
|
||||
Arc::new(StringArray::from(vec![None::<&str>])),
|
||||
],
|
||||
None,
|
||||
);
|
||||
let table_schema = Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
wide_blob_field("image"),
|
||||
]);
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", prebuilt.data_type().clone(), true),
|
||||
Arc::new(prebuilt),
|
||||
);
|
||||
let coerced = coerce(batch, &table_schema).await;
|
||||
let image = image_struct(&coerced);
|
||||
assert_eq!(image.num_columns(), 4);
|
||||
assert!(image.column_by_name("position").unwrap().is_null(0));
|
||||
assert!(image.column_by_name("size").unwrap().is_null(0));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn external_reference_struct_preserves_uri_position_and_size() {
|
||||
let prebuilt = StructArray::new(
|
||||
vec![
|
||||
Field::new("data", DataType::LargeBinary, true),
|
||||
Field::new("uri", DataType::Utf8, true),
|
||||
Field::new("position", DataType::UInt64, true),
|
||||
Field::new("size", DataType::UInt64, true),
|
||||
]
|
||||
.into(),
|
||||
vec![
|
||||
Arc::new(LargeBinaryArray::from(vec![None::<&[u8]>])) as ArrayRef,
|
||||
Arc::new(StringArray::from(vec![Some("s3://bucket/blob.bin")])) as ArrayRef,
|
||||
Arc::new(UInt64Array::from(vec![Some(7)])) as ArrayRef,
|
||||
Arc::new(UInt64Array::from(vec![Some(6)])) as ArrayRef,
|
||||
],
|
||||
None,
|
||||
);
|
||||
let table_schema = Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
wide_blob_field("image"),
|
||||
]);
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", prebuilt.data_type().clone(), true),
|
||||
Arc::new(prebuilt),
|
||||
);
|
||||
let coerced = coerce(batch, &table_schema).await;
|
||||
let image = image_struct(&coerced);
|
||||
|
||||
let uri: &StringArray = image
|
||||
.column_by_name("uri")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap();
|
||||
assert_eq!(uri.value(0), "s3://bucket/blob.bin");
|
||||
let position: &UInt64Array = image
|
||||
.column_by_name("position")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap();
|
||||
assert_eq!(position.value(0), 7);
|
||||
let size: &UInt64Array = image
|
||||
.column_by_name("size")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap();
|
||||
assert_eq!(size.value(0), 6);
|
||||
assert!(image.column_by_name("data").unwrap().is_null(0));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn descriptor_struct_without_value_child_is_rejected() {
|
||||
let descriptor = StructArray::new(
|
||||
vec![
|
||||
Field::new("kind", DataType::UInt8, false),
|
||||
Field::new("position", DataType::UInt64, false),
|
||||
Field::new("size", DataType::UInt64, false),
|
||||
]
|
||||
.into(),
|
||||
vec![
|
||||
Arc::new(UInt8Array::from(vec![0])),
|
||||
Arc::new(UInt64Array::from(vec![0])),
|
||||
Arc::new(UInt64Array::from(vec![0])),
|
||||
],
|
||||
None,
|
||||
);
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", descriptor.data_type().clone(), true),
|
||||
Arc::new(descriptor),
|
||||
);
|
||||
let err = coerce_err(batch, &blob_table_schema()).await;
|
||||
assert!(err.to_string().contains("'data' or 'uri'"));
|
||||
assert!(err.to_string().contains("image"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unsupported_input_type_is_rejected_with_column_name() {
|
||||
let batch = batch_with_image(
|
||||
Field::new("image", DataType::Utf8, true),
|
||||
Arc::new(StringArray::from(vec!["not bytes"])),
|
||||
);
|
||||
let err = coerce_err(batch, &blob_table_schema()).await;
|
||||
assert!(matches!(err, Error::InvalidInput { .. }), "got {err:?}");
|
||||
assert!(err.to_string().contains("image"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blob_metadata_survives_cast_of_sibling_column() {
|
||||
let batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("image", DataType::LargeBinary, true),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(Int32Array::from(vec![1])),
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"x".as_slice()])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let coerced = coerce(batch, &blob_table_schema()).await;
|
||||
|
||||
let image_field = coerced.schema().field_with_name("image").unwrap().clone();
|
||||
assert!(
|
||||
image_field.is_blob_v2(),
|
||||
"expected blob marker on image field, got {:?}",
|
||||
image_field.metadata()
|
||||
);
|
||||
assert_eq!(
|
||||
coerced.schema().field_with_name("id").unwrap().data_type(),
|
||||
&DataType::Int64
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn exact_blob_input_passes_through_unchanged() {
|
||||
let DataType::Struct(children) = blob("image", true).data_type().clone() else {
|
||||
unreachable!("blob field is a struct")
|
||||
};
|
||||
let image = StructArray::new(
|
||||
children,
|
||||
vec![
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"exact".as_slice()])),
|
||||
Arc::new(StringArray::from(vec![None::<&str>])),
|
||||
],
|
||||
None,
|
||||
);
|
||||
let batch = batch_with_image(blob("image", true), Arc::new(image));
|
||||
let table_schema = blob_table_schema();
|
||||
|
||||
let input = plan_from_batch(batch).await;
|
||||
let input_ptr = Arc::as_ptr(&input);
|
||||
let plan = cast_to_table_schema(input, &table_schema).unwrap();
|
||||
assert_eq!(Arc::as_ptr(&plan), input_ptr, "no projection inserted");
|
||||
}
|
||||
}
|
||||
@@ -13,10 +13,8 @@ use datafusion_physical_expr::expressions::{CastExpr, Literal};
|
||||
use datafusion_physical_plan::expressions::Column;
|
||||
use datafusion_physical_plan::projection::ProjectionExec;
|
||||
use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};
|
||||
use lance_arrow::FieldExt;
|
||||
use lance_arrow::json::{is_arrow_json_field, is_json_field};
|
||||
|
||||
use super::blob_coerce::coerce_blob_expr;
|
||||
use crate::{Error, Result};
|
||||
|
||||
pub fn cast_to_table_schema(
|
||||
@@ -79,17 +77,6 @@ fn build_field_exprs(
|
||||
continue;
|
||||
}
|
||||
|
||||
// Blob columns accept raw binary on write; exact matches pass through below.
|
||||
if table_field.is_blob_v2() && input_field.as_ref() != table_field.as_ref() {
|
||||
result.push(coerce_blob_expr(
|
||||
input_expr,
|
||||
input_field,
|
||||
table_field,
|
||||
&config,
|
||||
)?);
|
||||
continue;
|
||||
}
|
||||
|
||||
let expr = match (input_field.data_type(), table_field.data_type()) {
|
||||
// Both are structs: recurse into sub-fields to handle subschemas and casts.
|
||||
(DataType::Struct(in_children), DataType::Struct(tbl_children))
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
//! DataFusion ExecutionPlan for inserting data into LanceDB tables.
|
||||
|
||||
use std::any::Any;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, LazyLock, Mutex};
|
||||
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
@@ -21,12 +20,11 @@ use datafusion_physical_plan::{
|
||||
use futures::TryStreamExt;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::transaction::{Operation, Transaction};
|
||||
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams, WriteProgressFn};
|
||||
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
|
||||
use lance::io::exec::utils::InstrumentedRecordBatchStreamAdapter;
|
||||
use lance_table::format::Fragment;
|
||||
|
||||
use crate::table::dataset::DatasetConsistencyWrapper;
|
||||
use crate::table::write_progress::WriteProgressTracker;
|
||||
|
||||
pub(crate) static COUNT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
|
||||
Arc::new(ArrowSchema::new(vec![Field::new(
|
||||
@@ -83,7 +81,6 @@ pub struct InsertExec {
|
||||
dataset: Arc<Dataset>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
write_params: WriteParams,
|
||||
tracker: Option<Arc<WriteProgressTracker>>,
|
||||
properties: Arc<PlanProperties>,
|
||||
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
|
||||
metrics: ExecutionPlanMetricsSet,
|
||||
@@ -95,16 +92,6 @@ impl InsertExec {
|
||||
dataset: Arc<Dataset>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
write_params: WriteParams,
|
||||
) -> Self {
|
||||
Self::new_with_tracker(ds_wrapper, dataset, input, write_params, None)
|
||||
}
|
||||
|
||||
pub(crate) fn new_with_tracker(
|
||||
ds_wrapper: DatasetConsistencyWrapper,
|
||||
dataset: Arc<Dataset>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
write_params: WriteParams,
|
||||
tracker: Option<Arc<WriteProgressTracker>>,
|
||||
) -> Self {
|
||||
let schema = COUNT_SCHEMA.clone();
|
||||
let num_partitions = input.output_partitioning().partition_count();
|
||||
@@ -120,7 +107,6 @@ impl InsertExec {
|
||||
dataset,
|
||||
input,
|
||||
write_params,
|
||||
tracker,
|
||||
properties: Arc::new(properties),
|
||||
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
|
||||
metrics: ExecutionPlanMetricsSet::new(),
|
||||
@@ -175,12 +161,11 @@ impl ExecutionPlan for InsertExec {
|
||||
"InsertExec requires exactly one child".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(Arc::new(Self::new_with_tracker(
|
||||
Ok(Arc::new(Self::new(
|
||||
self.ds_wrapper.clone(),
|
||||
self.dataset.clone(),
|
||||
children[0].clone(),
|
||||
self.write_params.clone(),
|
||||
self.tracker.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
@@ -191,11 +176,10 @@ impl ExecutionPlan for InsertExec {
|
||||
) -> DataFusionResult<SendableRecordBatchStream> {
|
||||
let input_stream = self.input.execute(partition, context)?;
|
||||
let dataset = self.dataset.clone();
|
||||
let mut write_params = self.write_params.clone();
|
||||
let write_params = self.write_params.clone();
|
||||
let partial_transactions = self.partial_transactions.clone();
|
||||
let total_partitions = self.input.output_partitioning().partition_count();
|
||||
let ds_wrapper = self.ds_wrapper.clone();
|
||||
let tracker = self.tracker.clone();
|
||||
|
||||
let output_bytes = MetricBuilder::new(&self.metrics).output_bytes(partition);
|
||||
let input_schema = input_stream.schema();
|
||||
@@ -211,20 +195,6 @@ impl ExecutionPlan for InsertExec {
|
||||
));
|
||||
|
||||
let stream = futures::stream::once(async move {
|
||||
if let Some(tracker) = tracker
|
||||
&& write_params.write_progress.is_none()
|
||||
{
|
||||
let last_bytes = Arc::new(AtomicU64::new(0));
|
||||
write_params.write_progress = Some(WriteProgressFn::new(move |stats| {
|
||||
let previous = last_bytes.swap(stats.bytes_written, Ordering::Relaxed);
|
||||
if stats.bytes_written > previous {
|
||||
let delta =
|
||||
usize::try_from(stats.bytes_written - previous).unwrap_or(usize::MAX);
|
||||
tracker.record_bytes(delta);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
let transaction = InsertBuilder::new(dataset.clone())
|
||||
.with_params(&write_params)
|
||||
.execute_uncommitted_stream(input_stream)
|
||||
|
||||
@@ -518,10 +518,6 @@ mod tests {
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, Some(Duration::from_millis(200)));
|
||||
|
||||
// Freeze `cached_at` on the mock clock so a slow external write below can't
|
||||
// expire the TTL before the explicit advance_by() does (flake on loaded CI).
|
||||
clock::pin();
|
||||
|
||||
// Populate the cache
|
||||
let v1 = wrapper.get().await.unwrap().version().version;
|
||||
assert_eq!(v1, 1);
|
||||
|
||||
@@ -579,45 +579,24 @@ fn array_to_f32_vec(arr: &Arc<dyn arrow_array::Array>) -> Result<Vec<f32>> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Magic bytes that prefix (and suffix) the Arrow IPC *file* format.
|
||||
const ARROW_IPC_FILE_MAGIC: &[u8] = b"ARROW1";
|
||||
|
||||
/// Parse Arrow IPC response from the namespace server.
|
||||
///
|
||||
/// The server may return either the Arrow IPC *file* format or the *stream*
|
||||
/// format. REST/phalanx returns the file format (it begins with the `ARROW1`
|
||||
/// magic); reading that with a `StreamReader` fails with "failed to fill whole
|
||||
/// buffer". Detect the magic and pick the matching reader so both are handled.
|
||||
async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result<DatasetRecordBatchStream> {
|
||||
use arrow_ipc::reader::{FileReader, StreamReader};
|
||||
use arrow_ipc::reader::StreamReader;
|
||||
use std::io::Cursor;
|
||||
|
||||
let (schema, batches) = if bytes.starts_with(ARROW_IPC_FILE_MAGIC) {
|
||||
let reader = FileReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse Arrow IPC file response: {}", e),
|
||||
let cursor = Cursor::new(bytes);
|
||||
let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse Arrow IPC response: {}", e),
|
||||
})?;
|
||||
|
||||
// Collect all record batches
|
||||
let schema = reader.schema();
|
||||
let batches: Vec<_> = reader
|
||||
.into_iter()
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read Arrow IPC batches: {}", e),
|
||||
})?;
|
||||
let schema = reader.schema();
|
||||
let batches = reader
|
||||
.into_iter()
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read Arrow IPC file batches: {}", e),
|
||||
})?;
|
||||
(schema, batches)
|
||||
} else {
|
||||
let reader =
|
||||
StreamReader::try_new(Cursor::new(bytes), None).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse Arrow IPC response: {}", e),
|
||||
})?;
|
||||
let schema = reader.schema();
|
||||
let batches = reader
|
||||
.into_iter()
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to read Arrow IPC batches: {}", e),
|
||||
})?;
|
||||
(schema, batches)
|
||||
};
|
||||
|
||||
// Create a stream from the batches
|
||||
let stream = futures::stream::iter(batches.into_iter().map(Ok));
|
||||
@@ -645,59 +624,6 @@ mod tests {
|
||||
FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_arrow_ipc_response_handles_file_and_stream() {
|
||||
use arrow_array::{Int32Array, RecordBatch};
|
||||
use arrow_ipc::writer::{FileWriter, StreamWriter};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Arrow IPC *file* format -- what REST/phalanx returns. Previously this
|
||||
// failed with "failed to fill whole buffer" because we used a StreamReader.
|
||||
let mut file_buf = Vec::new();
|
||||
{
|
||||
let mut writer = FileWriter::try_new(&mut file_buf, &schema).unwrap();
|
||||
writer.write(&batch).unwrap();
|
||||
writer.finish().unwrap();
|
||||
}
|
||||
assert!(file_buf.starts_with(ARROW_IPC_FILE_MAGIC));
|
||||
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(file_buf))
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|b| b.num_rows())
|
||||
.sum();
|
||||
assert_eq!(rows, 3);
|
||||
|
||||
// Arrow IPC *stream* format must still parse.
|
||||
let mut stream_buf = Vec::new();
|
||||
{
|
||||
let mut writer = StreamWriter::try_new(&mut stream_buf, &schema).unwrap();
|
||||
writer.write(&batch).unwrap();
|
||||
writer.finish().unwrap();
|
||||
}
|
||||
assert!(!stream_buf.starts_with(ARROW_IPC_FILE_MAGIC));
|
||||
let rows: usize = parse_arrow_ipc_response(bytes::Bytes::from(stream_buf))
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|b| b.num_rows())
|
||||
.sum();
|
||||
assert_eq!(rows, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_to_namespace_query_vector() {
|
||||
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0]));
|
||||
|
||||
@@ -142,21 +142,11 @@ impl WriteProgressTracker {
|
||||
cb(&progress);
|
||||
}
|
||||
|
||||
/// Record wire bytes from the insert layer.
|
||||
///
|
||||
/// These bytes may be IPC-encoded bytes for remote writes or bytes handed
|
||||
/// to Lance's local writer. When wire bytes are recorded, they take
|
||||
/// precedence over the in-memory Arrow bytes tracked by [`record_batch`].
|
||||
/// Record wire bytes from the insert layer (e.g. IPC-encoded bytes for
|
||||
/// remote writes). When wire bytes are recorded, they take precedence over
|
||||
/// the in-memory Arrow bytes tracked by [`record_batch`].
|
||||
pub fn record_bytes(&self, bytes: usize) {
|
||||
self.wire_bytes.fetch_add(bytes, Ordering::Relaxed);
|
||||
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let guard = self
|
||||
.rows_and_bytes
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let progress = self.snapshot(guard.0, guard.1, false);
|
||||
drop(guard);
|
||||
cb(&progress);
|
||||
}
|
||||
|
||||
/// Emit the final progress callback indicating the write is complete.
|
||||
@@ -179,6 +169,8 @@ impl WriteProgressTracker {
|
||||
let wire = self.wire_bytes.load(Ordering::Relaxed);
|
||||
// Prefer wire bytes (actual I/O size) when the insert layer is
|
||||
// tracking them; fall back to in-memory Arrow size otherwise.
|
||||
// TODO: for local writes, track actual bytes written by Lance
|
||||
// instead of using in-memory Arrow size as a proxy.
|
||||
let output_bytes = if wire > 0 { wire } else { in_memory_bytes };
|
||||
WriteProgress {
|
||||
elapsed: self.start.elapsed(),
|
||||
@@ -391,54 +383,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_progress_uses_lance_write_bytes_for_local_tables() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db = connect(dir.path().to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
|
||||
let table = db
|
||||
.create_table("local_write_bytes", batch)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let new_data = record_batch!(("id", Int32, [4, 5, 6])).unwrap();
|
||||
let in_memory_bytes = new_data.get_array_memory_size();
|
||||
let final_bytes = Arc::new(AtomicUsize::new(0));
|
||||
let seen_non_memory_bytes = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let final_bytes_cb = final_bytes.clone();
|
||||
let seen_non_memory_bytes_cb = seen_non_memory_bytes.clone();
|
||||
|
||||
table
|
||||
.add(new_data)
|
||||
.write_parallelism(1)
|
||||
.progress(move |p| {
|
||||
if p.output_bytes() > 0 && p.output_bytes() != in_memory_bytes {
|
||||
seen_non_memory_bytes_cb.store(true, Ordering::SeqCst);
|
||||
}
|
||||
if p.done() {
|
||||
final_bytes_cb.store(p.output_bytes(), Ordering::SeqCst);
|
||||
}
|
||||
})
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
seen_non_memory_bytes.load(Ordering::SeqCst),
|
||||
"progress should report Lance writer bytes, not only Arrow memory bytes"
|
||||
);
|
||||
assert_ne!(
|
||||
final_bytes.load(Ordering::SeqCst),
|
||||
in_memory_bytes,
|
||||
"final progress bytes should come from Lance write stats"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_batch_recovers_from_poisoned_callback_lock() {
|
||||
use super::{ProgressCallback, WriteProgressTracker};
|
||||
|
||||
@@ -329,15 +329,6 @@ pub mod clock {
|
||||
});
|
||||
}
|
||||
|
||||
/// Start mock time at the current instant if not already pinned.
|
||||
pub fn pin() {
|
||||
MOCK_NOW.with(|mock| {
|
||||
if mock.get().is_none() {
|
||||
mock.set(Some(Instant::now()));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn clear_mock() {
|
||||
MOCK_NOW.with(|mock| mock.set(None));
|
||||
|
||||
@@ -1,949 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
Array, ArrayRef, BinaryArray, Int64Array, LargeBinaryArray, RecordBatch, StringArray,
|
||||
StructArray, UInt64Array,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Fields, Schema};
|
||||
use futures::TryStreamExt;
|
||||
use lance_encoding::version::LanceFileVersion;
|
||||
use lancedb::{
|
||||
Connection, Error, Result, Table,
|
||||
blob::blob,
|
||||
connect, connect_namespace,
|
||||
database::listing::OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS,
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
table::{AddDataMode, CompactionOptions, OptimizeAction},
|
||||
};
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn blob_table_schema() -> Arc<Schema> {
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
blob("image", true),
|
||||
]))
|
||||
}
|
||||
|
||||
fn binary_input_batch(ids: &[i64], payloads: &[Option<&[u8]>]) -> RecordBatch {
|
||||
RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
Field::new("image", DataType::LargeBinary, true),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(Int64Array::from(ids.to_vec())),
|
||||
Arc::new(LargeBinaryArray::from_iter(payloads.iter().copied())),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn create_inline_blob_table(
|
||||
db: &Connection,
|
||||
name: &str,
|
||||
ids: &[i64],
|
||||
payloads: &[Option<&[u8]>],
|
||||
) -> Result<Table> {
|
||||
let table = db
|
||||
.create_empty_table(name, blob_table_schema())
|
||||
.execute()
|
||||
.await?;
|
||||
table
|
||||
.add(binary_input_batch(ids, payloads))
|
||||
.execute()
|
||||
.await?;
|
||||
Ok(table)
|
||||
}
|
||||
|
||||
async fn storage_format_version(table: &Table) -> LanceFileVersion {
|
||||
table
|
||||
.as_native()
|
||||
.unwrap()
|
||||
.manifest()
|
||||
.await
|
||||
.unwrap()
|
||||
.data_storage_format
|
||||
.lance_file_version()
|
||||
.unwrap()
|
||||
.resolve()
|
||||
}
|
||||
|
||||
async fn uses_stable_row_ids(table: &Table) -> bool {
|
||||
table
|
||||
.as_native()
|
||||
.unwrap()
|
||||
.manifest()
|
||||
.await
|
||||
.unwrap()
|
||||
.uses_stable_row_ids()
|
||||
}
|
||||
|
||||
async fn query_image_struct(table: &Table) -> StructArray {
|
||||
let batches = table
|
||||
.query()
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = arrow_select::concat::concat_batches(&batches[0].schema(), &batches).unwrap();
|
||||
batch
|
||||
.column_by_name("image")
|
||||
.expect("image column present")
|
||||
.as_any()
|
||||
.downcast_ref::<StructArray>()
|
||||
.expect("image column is a descriptor struct")
|
||||
.clone()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn declaring_blob_column_bumps_format_and_enables_stable_row_ids() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = db
|
||||
.create_empty_table("t", blob_table_schema())
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
|
||||
assert!(uses_stable_row_ids(&table).await);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn explicit_stable_row_id_setting_wins_over_blob_default() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = db
|
||||
.create_empty_table("t", blob_table_schema())
|
||||
.storage_option(OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS, "false")
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
|
||||
assert!(!uses_stable_row_ids(&table).await);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn non_blob_table_keeps_default_format_and_row_id_setting() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
|
||||
let table = db.create_empty_table("t", schema).execute().await?;
|
||||
|
||||
assert!(storage_format_version(&table).await < LanceFileVersion::V2_2);
|
||||
assert!(!uses_stable_row_ids(&table).await);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn creating_with_blob_data_bumps_format() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
|
||||
let blob_field = blob("image", true);
|
||||
let DataType::Struct(children) = blob_field.data_type().clone() else {
|
||||
unreachable!("blob field is a struct")
|
||||
};
|
||||
let image = StructArray::new(
|
||||
children,
|
||||
vec![
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"payload".as_slice()])),
|
||||
Arc::new(StringArray::from(vec![None::<&str>])),
|
||||
],
|
||||
None,
|
||||
);
|
||||
let batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
blob_field,
|
||||
])),
|
||||
vec![Arc::new(Int64Array::from(vec![1])), Arc::new(image)],
|
||||
)
|
||||
.unwrap();
|
||||
let table = db.create_table("t", batch).execute().await?;
|
||||
|
||||
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
|
||||
assert!(uses_stable_row_ids(&table).await);
|
||||
assert_eq!(table.count_rows(None).await?, 1);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_coerces_large_binary_into_blob_column() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table =
|
||||
create_inline_blob_table(&db, "t", &[1, 2], &[Some(b"cat".as_slice()), Some(b"dog")])
|
||||
.await?;
|
||||
|
||||
assert_eq!(table.count_rows(None).await?, 2);
|
||||
let image = query_image_struct(&table).await;
|
||||
assert_eq!(image.len(), 2);
|
||||
let schema = table.schema().await?;
|
||||
let field = schema.field_with_name("image").unwrap();
|
||||
assert_eq!(
|
||||
field
|
||||
.metadata()
|
||||
.get("ARROW:extension:name")
|
||||
.map(String::as_str),
|
||||
Some("lance.blob.v2")
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_coerces_binary_into_blob_column() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = db
|
||||
.create_empty_table("t", blob_table_schema())
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
Field::new("image", DataType::Binary, true),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(Int64Array::from(vec![1])),
|
||||
Arc::new(BinaryArray::from_iter_values([b"small".as_slice()])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
table.add(batch).execute().await?;
|
||||
|
||||
assert_eq!(table.count_rows(None).await?, 1);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_accepts_null_blob_rows() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(
|
||||
&db,
|
||||
"t",
|
||||
&[1, 2, 3],
|
||||
&[Some(b"first".as_slice()), None, Some(b"third")],
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(table.count_rows(None).await?, 3);
|
||||
let image = query_image_struct(&table).await;
|
||||
assert_eq!(image.len(), 3);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_rejects_uncoercible_blob_input() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = db
|
||||
.create_empty_table("t", blob_table_schema())
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
Field::new("image", DataType::Utf8, true),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(Int64Array::from(vec![1])),
|
||||
Arc::new(StringArray::from(vec!["not bytes"])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let err = table.add(batch).execute().await.unwrap_err();
|
||||
assert!(err.to_string().contains("image"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_level_stable_row_id_setting_wins_over_blob_default() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap())
|
||||
.storage_option(OPT_NEW_TABLE_ENABLE_STABLE_ROW_IDS, "false")
|
||||
.execute()
|
||||
.await?;
|
||||
let table = db
|
||||
.create_empty_table("t", blob_table_schema())
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
|
||||
assert!(!uses_stable_row_ids(&table).await);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn namespace_create_applies_blob_defaults() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let mut properties = std::collections::HashMap::new();
|
||||
properties.insert("root".to_string(), tmp.path().to_str().unwrap().to_string());
|
||||
let db = connect_namespace("dir", properties).execute().await?;
|
||||
let table = db
|
||||
.create_empty_table("t", blob_table_schema())
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
|
||||
assert!(uses_stable_row_ids(&table).await);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Overwrite takes the input schema as-is. A raw-binary overwrite drops the blob
|
||||
// marker; re-declaring blob v2 in the input restores it.
|
||||
#[tokio::test]
|
||||
async fn overwrite_replaces_blob_schema_with_input_schema() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"blob".as_slice())]).await?;
|
||||
|
||||
let raw_schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
Field::new("image", DataType::LargeBinary, true),
|
||||
]));
|
||||
let raw_batch = RecordBatch::try_new(
|
||||
raw_schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int64Array::from(vec![2])),
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"plain".as_slice()])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
table
|
||||
.add(raw_batch)
|
||||
.mode(AddDataMode::Overwrite)
|
||||
.execute()
|
||||
.await?;
|
||||
let schema = table.schema().await?;
|
||||
assert_eq!(schema, raw_schema);
|
||||
assert!(
|
||||
!schema
|
||||
.field_with_name("image")
|
||||
.unwrap()
|
||||
.metadata()
|
||||
.contains_key("ARROW:extension:name")
|
||||
);
|
||||
|
||||
let blob_field = blob("image", true);
|
||||
let DataType::Struct(children) = blob_field.data_type().clone() else {
|
||||
unreachable!("blob field is a struct")
|
||||
};
|
||||
let image = StructArray::new(
|
||||
children,
|
||||
vec![
|
||||
Arc::new(LargeBinaryArray::from_iter_values([b"declared".as_slice()])),
|
||||
Arc::new(StringArray::from(vec![None::<&str>])),
|
||||
],
|
||||
None,
|
||||
);
|
||||
let declared_batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
blob_field,
|
||||
])),
|
||||
vec![Arc::new(Int64Array::from(vec![3])), Arc::new(image)],
|
||||
)
|
||||
.unwrap();
|
||||
table
|
||||
.add(declared_batch)
|
||||
.mode(AddDataMode::Overwrite)
|
||||
.execute()
|
||||
.await?;
|
||||
let schema = table.schema().await?;
|
||||
assert_eq!(
|
||||
schema
|
||||
.field_with_name("image")
|
||||
.unwrap()
|
||||
.metadata()
|
||||
.get("ARROW:extension:name")
|
||||
.map(String::as_str),
|
||||
Some("lance.blob.v2")
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn collect_row_ids(table: &Table) -> Result<Vec<u64>> {
|
||||
let batches = table
|
||||
.query()
|
||||
.with_row_id()
|
||||
.execute()
|
||||
.await?
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
let batch = arrow_select::concat::concat_batches(&batches[0].schema(), &batches).unwrap();
|
||||
Ok(batch
|
||||
.column_by_name("_rowid")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.unwrap()
|
||||
.values()
|
||||
.to_vec())
|
||||
}
|
||||
|
||||
async fn collect_id_rowid(table: &Table) -> Result<Vec<(i64, u64)>> {
|
||||
let batches = table
|
||||
.query()
|
||||
.with_row_id()
|
||||
.execute()
|
||||
.await?
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
let batch = arrow_select::concat::concat_batches(&batches[0].schema(), &batches).unwrap();
|
||||
let ids = batch
|
||||
.column_by_name("id")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.unwrap();
|
||||
let row_ids = batch
|
||||
.column_by_name("_rowid")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.unwrap();
|
||||
Ok(ids
|
||||
.values()
|
||||
.iter()
|
||||
.copied()
|
||||
.zip(row_ids.values().iter().copied())
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_round_trips_bytes() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let payload: &[u8] = b"blob-round-trip-payload";
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(payload)]).await?;
|
||||
|
||||
let ids = collect_row_ids(&table).await?;
|
||||
let bytes = table.fetch_blobs("image", &ids).await?;
|
||||
assert_eq!(bytes.len(), 1);
|
||||
assert_eq!(bytes.value(0), payload);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_round_trips_nested_blob_column() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
|
||||
let blob_field = blob("blob", true);
|
||||
let DataType::Struct(blob_children) = blob_field.data_type().clone() else {
|
||||
unreachable!("blob field is a struct")
|
||||
};
|
||||
let blob_array = StructArray::new(
|
||||
blob_children,
|
||||
vec![
|
||||
Arc::new(LargeBinaryArray::from_iter_values([
|
||||
b"hello".as_slice(),
|
||||
b"world".as_slice(),
|
||||
])) as ArrayRef,
|
||||
Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])) as ArrayRef,
|
||||
],
|
||||
None,
|
||||
);
|
||||
let info_fields: Fields = vec![Field::new("name", DataType::Utf8, false), blob_field].into();
|
||||
let info_array = StructArray::new(
|
||||
info_fields.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef,
|
||||
Arc::new(blob_array) as ArrayRef,
|
||||
],
|
||||
None,
|
||||
);
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"info",
|
||||
DataType::Struct(info_fields),
|
||||
true,
|
||||
)]));
|
||||
let batch = RecordBatch::try_new(schema, vec![Arc::new(info_array) as ArrayRef]).unwrap();
|
||||
let table = db.create_table("t", batch).execute().await?;
|
||||
|
||||
assert!(storage_format_version(&table).await >= LanceFileVersion::V2_2);
|
||||
assert!(uses_stable_row_ids(&table).await);
|
||||
|
||||
let ids = collect_row_ids(&table).await?;
|
||||
let bytes = table.fetch_blobs("info.blob", &ids).await?;
|
||||
assert_eq!(bytes.len(), 2);
|
||||
let values: std::collections::HashSet<&[u8]> =
|
||||
(0..bytes.len()).map(|i| bytes.value(i)).collect();
|
||||
assert!(values.contains(b"hello".as_slice()));
|
||||
assert!(values.contains(b"world".as_slice()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blob_columns_lists_nested_dotted_paths() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let blob_field = blob("blob", true);
|
||||
let info = Field::new(
|
||||
"info",
|
||||
DataType::Struct(vec![Field::new("name", DataType::Utf8, false), blob_field].into()),
|
||||
true,
|
||||
);
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
blob("thumbnail", true),
|
||||
Field::new("id", DataType::Int64, false),
|
||||
info,
|
||||
]));
|
||||
let table = db.create_empty_table("t", schema).execute().await?;
|
||||
assert_eq!(table.blob_columns().await?, vec!["thumbnail", "info.blob"]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blob_columns_lists_blob_fields_in_order() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
blob("thumbnail", true),
|
||||
Field::new("id", DataType::Int64, false),
|
||||
blob("image", true),
|
||||
]));
|
||||
let table = db.create_empty_table("t", schema).execute().await?;
|
||||
assert_eq!(table.blob_columns().await?, vec!["thumbnail", "image"]);
|
||||
|
||||
let plain = db
|
||||
.create_empty_table(
|
||||
"plain",
|
||||
Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])),
|
||||
)
|
||||
.execute()
|
||||
.await?;
|
||||
assert!(plain.blob_columns().await?.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_preserves_null_alignment() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(
|
||||
&db,
|
||||
"t",
|
||||
&[1, 2, 3, 4],
|
||||
&[Some(b"a".as_slice()), None, Some(b"c"), None],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let pairs = collect_id_rowid(&table).await?;
|
||||
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
|
||||
let bytes = table.fetch_blobs("image", &ids).await?;
|
||||
assert_eq!(bytes.len(), ids.len());
|
||||
for (i, (id, _)) in pairs.iter().enumerate() {
|
||||
match id {
|
||||
1 => assert_eq!(bytes.value(i), b"a"),
|
||||
2 | 4 => assert!(bytes.is_null(i)),
|
||||
3 => assert_eq!(bytes.value(i), b"c"),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_all_null_column_returns_all_nulls() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1, 2], &[None, None]).await?;
|
||||
|
||||
let ids = collect_row_ids(&table).await?;
|
||||
let bytes = table.fetch_blobs("image", &ids).await?;
|
||||
assert_eq!(bytes.len(), 2);
|
||||
assert_eq!(bytes.null_count(), 2);
|
||||
|
||||
let files = table.fetch_blob_files("image", &ids).await?;
|
||||
assert_eq!(files.len(), 2);
|
||||
assert!(files.iter().all(Option::is_none));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_aligns_with_reordered_and_duplicate_ids() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(
|
||||
&db,
|
||||
"t",
|
||||
&[1, 2, 3],
|
||||
&[Some(b"one".as_slice()), Some(b"two"), Some(b"three")],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let pairs = collect_id_rowid(&table).await?;
|
||||
let by_id = |want: i64| pairs.iter().find(|(id, _)| *id == want).unwrap().1;
|
||||
let request = vec![by_id(3), by_id(1), by_id(3), by_id(2)];
|
||||
let bytes = table.fetch_blobs("image", &request).await?;
|
||||
assert_eq!(bytes.len(), 4);
|
||||
assert_eq!(bytes.value(0), b"three");
|
||||
assert_eq!(bytes.value(1), b"one");
|
||||
assert_eq!(bytes.value(2), b"three");
|
||||
assert_eq!(bytes.value(3), b"two");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_empty_ids_returns_empty() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
|
||||
|
||||
assert_eq!(table.fetch_blobs("image", &[]).await?.len(), 0);
|
||||
assert!(table.fetch_blob_files("image", &[]).await?.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_out_of_range_id_errors_without_panic() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
|
||||
|
||||
let err = table.fetch_blobs("image", &[u64::MAX]).await.unwrap_err();
|
||||
assert!(err.to_string().contains("row ids"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_rejects_non_blob_column() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
|
||||
|
||||
let err = table.fetch_blobs("id", &[0]).await.unwrap_err();
|
||||
assert!(matches!(err, Error::InvalidInput { .. }));
|
||||
assert!(err.to_string().contains("'id' is not a blob column"));
|
||||
|
||||
let err = table.fetch_blob_files("id", &[0]).await.unwrap_err();
|
||||
assert!(err.to_string().contains("'id' is not a blob column"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_rejects_unknown_column() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"x".as_slice())]).await?;
|
||||
|
||||
let err = table.fetch_blobs("missing", &[0]).await.unwrap_err();
|
||||
assert!(err.to_string().contains("no column named 'missing'"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_rejects_legacy_v1_blob_column() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let legacy = Field::new("image", DataType::LargeBinary, true).with_metadata(
|
||||
std::collections::HashMap::from([("lance-encoding:blob".to_string(), "true".to_string())]),
|
||||
);
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
legacy,
|
||||
]));
|
||||
let table = db.create_empty_table("t", schema).execute().await?;
|
||||
|
||||
let err = table.fetch_blobs("image", &[0]).await.unwrap_err();
|
||||
assert!(err.to_string().contains("legacy blob column"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blob_files_reads_lazily_and_aligns_nulls() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table =
|
||||
create_inline_blob_table(&db, "t", &[1, 2], &[Some(b"lazy-bytes".as_slice()), None])
|
||||
.await?;
|
||||
|
||||
let pairs = collect_id_rowid(&table).await?;
|
||||
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
|
||||
let files = table.fetch_blob_files("image", &ids).await?;
|
||||
assert_eq!(files.len(), 2);
|
||||
for ((id, _), file) in pairs.iter().zip(&files) {
|
||||
match id {
|
||||
1 => {
|
||||
let handle = file.as_ref().unwrap();
|
||||
assert_eq!(handle.read().await.unwrap().as_ref(), b"lazy-bytes");
|
||||
}
|
||||
2 => assert!(file.is_none()),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_reads_multiple_blob_columns_independently() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
blob("image", true),
|
||||
blob("thumbnail", true),
|
||||
]));
|
||||
let table = db.create_empty_table("t", schema).execute().await?;
|
||||
let batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int64, false),
|
||||
Field::new("image", DataType::LargeBinary, true),
|
||||
Field::new("thumbnail", DataType::LargeBinary, true),
|
||||
])),
|
||||
vec![
|
||||
Arc::new(Int64Array::from(vec![1, 2])),
|
||||
Arc::new(LargeBinaryArray::from_iter(vec![
|
||||
Some(b"image-1".as_slice()),
|
||||
None,
|
||||
])),
|
||||
Arc::new(LargeBinaryArray::from_iter(vec![
|
||||
None,
|
||||
Some(b"thumb-2".as_slice()),
|
||||
])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
table.add(batch).execute().await?;
|
||||
|
||||
let pairs = collect_id_rowid(&table).await?;
|
||||
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
|
||||
let images = table.fetch_blobs("image", &ids).await?;
|
||||
let thumbs = table.fetch_blobs("thumbnail", &ids).await?;
|
||||
for (i, (id, _)) in pairs.iter().enumerate() {
|
||||
match id {
|
||||
1 => {
|
||||
assert_eq!(images.value(i), b"image-1");
|
||||
assert!(thumbs.is_null(i));
|
||||
}
|
||||
2 => {
|
||||
assert!(images.is_null(i));
|
||||
assert_eq!(thumbs.value(i), b"thumb-2");
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_spans_fragments() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"frag-one".as_slice())]).await?;
|
||||
table
|
||||
.add(binary_input_batch(&[2], &[Some(b"frag-two".as_slice())]))
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
let pairs = collect_id_rowid(&table).await?;
|
||||
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
|
||||
let bytes = table.fetch_blobs("image", &ids).await?;
|
||||
for (i, (id, _)) in pairs.iter().enumerate() {
|
||||
match id {
|
||||
1 => assert_eq!(bytes.value(i), b"frag-one"),
|
||||
2 => assert_eq!(bytes.value(i), b"frag-two"),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_packed_payload_round_trip() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let big = vec![0xAB_u8; 100 * 1024];
|
||||
let small = b"small".to_vec();
|
||||
let table = create_inline_blob_table(
|
||||
&db,
|
||||
"t",
|
||||
&[1, 2],
|
||||
&[Some(big.as_slice()), Some(small.as_slice())],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let pairs = collect_id_rowid(&table).await?;
|
||||
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
|
||||
let bytes = table.fetch_blobs("image", &ids).await?;
|
||||
for (i, (id, _)) in pairs.iter().enumerate() {
|
||||
match id {
|
||||
1 => assert_eq!(bytes.value(i), big.as_slice()),
|
||||
2 => assert_eq!(bytes.value(i), small.as_slice()),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_after_delete() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(
|
||||
&db,
|
||||
"t",
|
||||
&[1, 2, 3],
|
||||
&[Some(b"one".as_slice()), Some(b"two"), Some(b"three")],
|
||||
)
|
||||
.await?;
|
||||
|
||||
table.delete("id = 2").await?;
|
||||
let pairs = collect_id_rowid(&table).await?;
|
||||
assert_eq!(pairs.len(), 2);
|
||||
let ids: Vec<u64> = pairs.iter().map(|(_, rowid)| *rowid).collect();
|
||||
let bytes = table.fetch_blobs("image", &ids).await?;
|
||||
for (i, (id, _)) in pairs.iter().enumerate() {
|
||||
match id {
|
||||
1 => assert_eq!(bytes.value(i), b"one"),
|
||||
3 => assert_eq!(bytes.value(i), b"three"),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_with_precompaction_row_ids_survives_compaction() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"frag-one".as_slice())]).await?;
|
||||
table
|
||||
.add(binary_input_batch(&[2], &[Some(b"frag-two".as_slice())]))
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
let pairs_before = collect_id_rowid(&table).await?;
|
||||
let ids_before: Vec<u64> = pairs_before.iter().map(|(_, rowid)| *rowid).collect();
|
||||
|
||||
table
|
||||
.optimize(OptimizeAction::Compact {
|
||||
options: CompactionOptions::default(),
|
||||
remap_options: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let bytes_after = table.fetch_blobs("image", &ids_before).await?;
|
||||
assert_eq!(bytes_after.len(), 2);
|
||||
for (i, (id, _)) in pairs_before.iter().enumerate() {
|
||||
match id {
|
||||
1 => assert_eq!(bytes_after.value(i), b"frag-one"),
|
||||
2 => assert_eq!(bytes_after.value(i), b"frag-two"),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn zero_length_blob_reads_back_as_null() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = create_inline_blob_table(&db, "t", &[1], &[Some(b"".as_slice())]).await?;
|
||||
|
||||
let ids = collect_row_ids(&table).await?;
|
||||
let bytes = table.fetch_blobs("image", &ids).await?;
|
||||
assert_eq!(bytes.len(), 1);
|
||||
assert!(bytes.is_null(0));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const DEDICATED_BLOB_LEN: usize = 64 * 1024;
|
||||
const SCRAMBLED_LOGICAL_IDS: [i64; 7] = [6, 3, 1, 4, 6, 2, 5];
|
||||
|
||||
fn dedicated_blob_bytes(tag: u8) -> Vec<u8> {
|
||||
vec![tag; DEDICATED_BLOB_LEN]
|
||||
}
|
||||
|
||||
async fn multi_fragment_dedicated_blob_table(db: &Connection) -> Result<Table> {
|
||||
let rows: [(i64, Option<u8>); 6] = [
|
||||
(1, Some(1)),
|
||||
(2, Some(2)),
|
||||
(3, None),
|
||||
(4, Some(4)),
|
||||
(5, None),
|
||||
(6, Some(6)),
|
||||
];
|
||||
let mut table: Option<Table> = None;
|
||||
for (logical_id, blob_tag) in rows {
|
||||
let bytes = blob_tag.map(dedicated_blob_bytes);
|
||||
let image = [bytes.as_deref()];
|
||||
table = Some(match table {
|
||||
None => create_inline_blob_table(db, "t", &[logical_id], &image).await?,
|
||||
Some(t) => {
|
||||
t.add(binary_input_batch(&[logical_id], &image))
|
||||
.execute()
|
||||
.await?;
|
||||
t
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(table.unwrap())
|
||||
}
|
||||
|
||||
async fn row_ids_for_logical(table: &Table, logical_ids: &[i64]) -> Result<Vec<u64>> {
|
||||
let id_rowid = collect_id_rowid(table).await?;
|
||||
Ok(logical_ids
|
||||
.iter()
|
||||
.map(|logical_id| {
|
||||
id_rowid
|
||||
.iter()
|
||||
.find(|(id, _)| id == logical_id)
|
||||
.map(|(_, row_id)| *row_id)
|
||||
.unwrap()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blobs_aligns_across_fragments_with_nulls_and_dups() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = multi_fragment_dedicated_blob_table(&db).await?;
|
||||
let row_ids = row_ids_for_logical(&table, &SCRAMBLED_LOGICAL_IDS).await?;
|
||||
|
||||
let bytes = table.fetch_blobs("image", &row_ids).await?;
|
||||
assert_eq!(bytes.len(), SCRAMBLED_LOGICAL_IDS.len());
|
||||
for (slot, logical_id) in SCRAMBLED_LOGICAL_IDS.iter().enumerate() {
|
||||
match logical_id {
|
||||
3 | 5 => assert!(bytes.is_null(slot)),
|
||||
id => assert_eq!(
|
||||
bytes.value(slot),
|
||||
dedicated_blob_bytes(*id as u8).as_slice()
|
||||
),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_blob_files_aligns_across_fragments_with_nulls_and_dups() -> Result<()> {
|
||||
let tmp = tempdir().unwrap();
|
||||
let db = connect(tmp.path().to_str().unwrap()).execute().await?;
|
||||
let table = multi_fragment_dedicated_blob_table(&db).await?;
|
||||
let row_ids = row_ids_for_logical(&table, &SCRAMBLED_LOGICAL_IDS).await?;
|
||||
|
||||
let files = table.fetch_blob_files("image", &row_ids).await?;
|
||||
assert_eq!(files.len(), SCRAMBLED_LOGICAL_IDS.len());
|
||||
for (slot, logical_id) in SCRAMBLED_LOGICAL_IDS.iter().enumerate() {
|
||||
match logical_id {
|
||||
3 | 5 => assert!(files[slot].is_none()),
|
||||
id => {
|
||||
let payload = files[slot].as_ref().unwrap().read().await?;
|
||||
assert_eq!(payload.as_ref(), dedicated_blob_bytes(*id as u8).as_slice());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user